Unverified Commit 734d8991 authored by Jeff Rasley, committed by GitHub
parent b652395e
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff_val);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// C++ interface
at::Tensor lamb(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay)
{
CHECK_INPUT(p);
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
CHECK_INPUT(m);
CHECK_INPUT(v);
CHECK_INPUT(g);
int64_t num_elem = p.numel();
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
AT_ASSERTM(
p_copy.numel() == num_elem || p_copy.numel() == 0,
"number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
// intermediate for weight L2 reduction
// make sure that the number of threads per block is at least 512 during the kernel launch,
// otherwise the behaviour is unexpected
at::Tensor w_l2_i = at::empty(
{512},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
: p.type().scalarType()));
// intermediate for update L2 reduction
// make sure that the number of threads per block is at least 512 during the kernel launch,
// otherwise the behaviour is unexpected
at::Tensor u_l2_i = at::empty(
{512},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
: p.type().scalarType()));
at::Tensor lamb_coeff_val = at::empty(
{1},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
: p.type().scalarType()));
fused_lamb_cuda(p,
p_copy,
m,
v,
g,
lr,
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step,
mode,
bias_correction,
decay,
w_l2_i,
u_l2_i,
lamb_coeff_val);
return lamb_coeff_val;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
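// ---------------------------------------------------------------------------
// Minimal usage sketch (not part of the extension itself): one LAMB step driven
// from C++ against the `lamb` binding above. The function name, tensor shapes and
// hyperparameter values are illustrative placeholders, not DeepSpeed defaults;
// <torch/extension.h> above already provides the torch:: tensor factories.
void lamb_step_example()
{
    auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto p = torch::randn({1024}, opts);    // parameters
    auto g = torch::randn({1024}, opts);    // gradients
    auto m = torch::zeros({1024}, opts);    // first-moment state
    auto v = torch::zeros({1024}, opts);    // second-moment state
    auto p_copy = torch::empty({0}, opts);  // empty: no fp16 parameter copy requested

    // Returns a 1-element tensor holding the computed LAMB trust-ratio coefficient.
    at::Tensor coeff = lamb(p, p_copy, m, v, g,
                            /*lr=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f,
                            /*max_coeff=*/10.0f, /*min_coeff=*/0.01f,
                            /*eps=*/1e-8f, /*grad_scale=*/1.0f,
                            /*step=*/1, /*mode=*/0, /*bias_correction=*/1,
                            /*decay=*/0.01f);
}
// ---------------------------------------------------------------------------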
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#include <cuda_runtime_api.h>
#include <cooperative_groups.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
@@ -23,54 +23,49 @@ namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
extern __device__ void error(void);
error();
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
extern __shared__ double s_double[];
return s_double;
}
};
} // namespace
#include "type_shim.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1  // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in global memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
@@ -84,104 +79,82 @@ reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
@@ -189,230 +162,217 @@ reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 and reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 and threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
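// The host launcher below runs these kernels in three phases: lamb_cuda_kernel_part1 updates the
// Adam moments and accumulates per-block partial sums of ||w||^2 and ||update||^2 into w_l2_i and
// u_l2_i; lamb_cuda_kernel_part2 reduces those per-block partials into element 0 of each buffer;
// lamb_cuda_kernel_part3 computes the trust ratio ||w|| / ||update|| (clamped to
// [min_coeff, max_coeff]) and applies the scaled update to the parameters.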
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>
<<<blocks, threadsPerBlock, smemsize, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
@@ -424,17 +384,17 @@ void fused_lamb_cuda(
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>
<<<1, threadsPerBlock, smemsize, stream>>>(
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
lamb_cuda_kernel_part3<accscalar_t, scalar_t>
<<<blocks, threadsPerBlock, smemsize, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
@@ -448,20 +408,20 @@ void fused_lamb_cuda(
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>
<<<blocks, threadsPerBlock, smemsize, stream>>>(
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
@@ -471,20 +431,19 @@ void fused_lamb_cuda(
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>
<<<1, threadsPerBlock, smemsize, stream>>>(
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
lamb_cuda_kernel_part3<scalar_t, scalar_t>
<<<blocks, threadsPerBlock, smemsize, stream>>>(
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
@@ -496,16 +455,15 @@ void fused_lamb_cuda(
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
THCudaCheck(cudaGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
#include "cublas_wrappers.h"
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
CUDA_R_32F,
(transa == CUBLAS_OP_N) ? m : k,
(const void*)B,
CUDA_R_32F,
(transb == CUBLAS_OP_N) ? k : n,
(const void*)beta,
C,
CUDA_R_32F,
m,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
CUDA_R_16F,
(transa == CUBLAS_OP_N) ? m : k,
(const void*)B,
CUDA_R_16F,
(transb == CUBLAS_OP_N) ? k : n,
(const void*)beta,
(void*)C,
CUDA_R_16F,
m,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasOperation_t op_A,
cublasOperation_t op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
(op_A == CUBLAS_OP_N) ? m : k,
stride_A,
B,
CUDA_R_32F,
(op_B == CUBLAS_OP_N) ? k : n,
stride_B,
beta,
C,
CUDA_R_32F,
m,
stride_C,
batch,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasOperation_t op_A,
cublasOperation_t op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
(op_A == CUBLAS_OP_N) ? m : k,
stride_A,
B,
CUDA_R_16F,
(op_B == CUBLAS_OP_N) ? k : n,
stride_B,
beta,
C,
CUDA_R_16F,
m,
stride_C,
batch,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
return EXIT_FAILURE;
}
return 0;
}
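// ---------------------------------------------------------------------------
// Usage sketch only (not part of the wrapper file): a single fp32 GEMM through
// cublas_gemm_ex above. The function name and matrix sizes are illustrative
// placeholders; cuBLAS assumes column-major storage for A, B and C.
#include <cuda_runtime.h>  // cudaMalloc/cudaFree (harmless if already pulled in)
void cublas_gemm_ex_example()
{
    cublasHandle_t handle;
    cublasCreate(&handle);

    const int m = 64, n = 64, k = 64;
    float *A, *B, *C;
    cudaMalloc((void**)&A, m * k * sizeof(float));
    cudaMalloc((void**)&B, k * n * sizeof(float));
    cudaMalloc((void**)&C, m * n * sizeof(float));

    const float alpha = 1.0f, beta = 0.0f;
    // CUBLAS_GEMM_DEFAULT lets cuBLAS pick an algorithm; the wrapper prints an
    // error and returns EXIT_FAILURE if the call does not succeed.
    cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, &beta, A, B, C, CUBLAS_GEMM_DEFAULT);

    cudaFree(A);
    cudaFree(B);
    cudaFree(C);
    cublasDestroy(handle);
}
// ---------------------------------------------------------------------------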
#include "custom_cuda_layers.h"
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
float4 rand = curand_uniform4(&state);
uint8_t m[4];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * 4;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < 4; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
int i = j * 4;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[4];
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
int i = j * 4;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint8_t* m = reinterpret_cast<uint8_t*>(mask_cast + j);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < 4; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
int i = j * 4;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
cudaStream_t stream,
bool bwd)
{
dim3 grid_dim = DS_GET_BLOCKS(total_count / 4);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
total_count, ratio, vals, out, mask, seed);
else
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
cudaStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
cudaStream_t stream,
bool);
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / 4)
{
uint8_t* m = reinterpret_cast<uint8_t*>(mask_cast + j);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < 4; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
float2 x_data = x_cast[j];
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
x_cast[j] = result_f;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / 2)
{
int i = j * 2;
Xdata[i] = (__half)((float)Xdata[i] * scale * mask[i]);
Xdata[i + 1] = (__half)((float)Xdata[i + 1] * scale * mask[i + 1]);
}
#endif
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
const float scale = 1. / (1. - ratio);
dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / 2), DS_CUDA_NUM_THREADS, 0, stream>>>(
total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
cudaStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
cudaStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(j, N / 2)
{
int i = j * 2;
out[i] = (__half)((float)Xdata[i] * scale * mask[i]);
out[i + 1] = (__half)((float)Xdata[i + 1] * scale * mask[i + 1]);
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
cudaStream_t stream)
{
const float scale = 1. / (1. - ratio);
dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / 2), DS_CUDA_NUM_THREADS, 0, stream>>>(
total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
cudaStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
cudaStream_t stream);
__global__ void dropout_kernel(const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
{
float4 rand = curand_uniform4(&state);
uint8_t m[4];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = blockIdx.x * dim + tid * 4;
float4 x_data = Xdata_cast[idx];
float4 b_data = bias_cast[tid];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
Xdata_cast[idx] = x_data;
}
}
__global__ void dropout_kernel(const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
{
int i = blockIdx.x * dim + tid * 4;
float4 rand = curand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[idx];
bias_f = bias_cast[tid];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint8_t m[4]; // = mask + i;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[idx] = result_f;
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
cudaStream_t stream)
{
dim3 grid_dim(batch); // DS_GET_BLOCKS(total_count/4);
dim3 block_dim(dim / 4); // DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
cudaStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
cudaStream_t stream);
__global__ void dropout_kernel(const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
{
float4 rand = curand_uniform4(&state);
uint8_t m[4];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
// int bid = k * blockDim.x + tid;
int i = blockIdx.x * dim + tid * 4;
float4 out_data = out_cast[idx];
float4 b_data = bias_cast[tid];
float4 res_data = residual_cast[idx];
float4 inp_data = input_cast[idx];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
out_cast[idx] = out_data;
}
}
__global__ void dropout_kernel(const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
{
int i = blockIdx.x * dim + tid * 4;
float4 rand = curand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
data_f = out_cast[idx];
bias_f = bias_cast[tid];
residual_f = residual_cast[idx];
input_f = input_cast[idx];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint8_t m[4]; // = mask + i;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[idx] = result_f;
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
cudaStream_t stream)
{
dim3 grid_dim(batch); // DS_GET_BLOCKS(total_count/4);
dim3 block_dim(dim / 4); // DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
cudaStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
cudaStream_t stream);
#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;
// C++ interface
template <typename T>
size_t get_workspace_size(int maxBatchSize,
int seq_len,
int hidden_size,
int heads,
bool training,
bool gelu_checkpoint)
{
size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
workSpacesize += (std::max((4 * size_t(maxBatchSize) * seq_len * hidden_size),
2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
}
return workSpacesize * sizeof(T);
}
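// Illustrative sizing only (hypothetical values, not defaults): for maxBatchSize = 8,
// seq_len = 512, hidden_size = 1024, heads = 16, training = true, gelu_checkpoint = false
// and T = __half:
//   base      = 4 * 8 * 512 * 1024                               = 16,777,216 elements
//   training += max(4 * 8 * 512 * 1024, 2 * 8 * 16 * 512 * 512)  = 67,108,864 elements
//   total     = 83,886,080 elements * sizeof(__half)             = 167,772,160 bytes (160 MiB)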
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
int batch_size,
int hidden_size,
int num_heads,
int intermediate_size,
int seq_length,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
: _layer_id(layer_id),
_batch_size(batch_size),
_hidden_size(hidden_size),
_heads(num_heads),
_intermediate_size(intermediate_size),
_seq_length(seq_length),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_attn_dropout_checkpoint(attn_dropout_checkpoint),
_normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size,
hidden_size,
gemm_algos[0])),
_attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
hidden_size,
gemm_algos[0])),
_norm_layer2(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
false,
false,
!normalize_invertible)),
_norm_layer3(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
false,
false,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
4 * hidden_size,
hidden_size,
gemm_algos[1])),
_ff2(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
4 * hidden_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio,
_batch_size * _heads * _seq_length,
_seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
_batch_size * _seq_length,
_hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
_batch_size * _seq_length,
_hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
_hidden_size / _heads,
(T(1.0) / T(sqrt(_hidden_size / _heads))),
T(0.0),
CUBLAS_OP_T,
CUBLAS_OP_N,
gemm_algos[3])),
_attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_hidden_size / _heads,
_seq_length,
_seq_length,
T(1.0),
T(0.0),
CUBLAS_OP_N,
CUBLAS_OP_N,
gemm_algos[4]))
{
assert(_hidden_size % _heads == 0);
assert(_seq_length <= 1024);
Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
Context::Instance().GenWorkSpace(get_workspace_size<T>(
_batch_size, _seq_length, _hidden_size, _heads, _training, _gelu_checkpoint));
if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
}
template <typename T>
void BertTransformerLayer<T>::Forward(int bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_qkvb_ptr,
const T* attn_ow_ptr,
const T* attn_ob_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* output_b_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* out_ptr,
T* inp_norm_ptr,
T* q_tf_ptr,
T* k_tf_ptr,
T* v_tf_ptr,
T* soft_out_ptr,
T* ctx_bufB_ptr,
T* attn_o_inp_ptr,
T* add_res_ptr,
T* ff1_inp_ptr,
T* gelu_inp_ptr,
T* ff2_inp_ptr)
{
cublasSetStream(_cublasHandle, _stream);
if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.ForwardCheckpoint(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_norm_layer3.Forward(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
_qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
launch_bias_add_transform_0213<T>(
q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);
int bsz_heads = bsz * _heads;
// attention scores
_attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
// Softmax + Mask
_softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);
// attn prob dropout.
_attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);
// attention context
_attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);
launch_transform4d_0213<T>(
attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);
if (_pre_or_postLayerNorm)
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
else
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);
// attn output dropout.
if (_pre_or_postLayerNorm)
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
else
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean())
_norm_layer2.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_norm_layer2.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
if (_norm_layer2.UseMean())
_norm_layer2.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_norm_layer2.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
ff1_inp_ptr,
inter_w_ptr,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
_gelu.ForwardWithBiasAdd(bsz,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
_stream);
_ff2.Forward(bsz_seq,
(_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
output_w_ptr,
out_ptr,
_cublasHandle);
// layer output dropout.
if (_pre_or_postLayerNorm)
_layer_output_dropout.ForwardWithBias(
bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
else
_layer_output_dropout.ForwardWithBias(
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.ForwardCheckpoint(
bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_norm_layer3.Forward(bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
template <typename T>
void BertTransformerLayer<T>::Backward(int bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
const T* inp_norm_ptr,
const T* q_tf_ptr,
const T* k_tf_ptr,
const T* v_tf_ptr,
const T* soft_out_ptr,
const T* ctx_bufB_ptr,
const T* attn_o_inp_ptr,
const T* add_res_ptr,
const T* ff1_inp_ptr,
const T* gelu_inp_ptr,
const T* ff2_inp_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_ow_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* grad_input_ptr,
T* grad_attn_qkvw_ptr,
T* grad_attn_qkvb_ptr,
T* grad_attn_ow_ptr,
T* grad_attn_ob_ptr,
T* grad_attn_nw_ptr,
T* grad_attn_nb_ptr,
T* grad_inter_w_ptr,
T* grad_inter_b_ptr,
T* grad_output_w_ptr,
T* grad_output_b_ptr,
T* grad_norm_w_ptr,
T* grad_norm_b_ptr)
{
cublasSetStream(_cublasHandle, _stream);
if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1 + small_buf_size;
T* buf_3 = buf_2 + small_buf_size;
T* ff2_buf = buf_3 + (_gelu_checkpoint ? 3 : 1) * small_buf_size;
T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
cudaStream_t streams[2] = {_stream, _stream};
int bsz_seq = bsz * _seq_length;
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.Backward(bsz,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
inp_norm_ptr);
else
_norm_layer3.Backward(bsz,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
output_ptr);
}
if (_pre_or_postLayerNorm)
_layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
else
_layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);
const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
output_w_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
_cublasHandle,
_stream,
ff2_buf);
_gelu.Backward(
bsz, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
ff1_inp_ptr,
inter_w_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
_cublasHandle,
_stream,
buf_3);
if (!_pre_or_postLayerNorm)
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean())
_norm_layer2.BackwardFusedAdd(bsz,
buf_3,
grad_output_ptr,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_norm_layer2.BackwardFusedAdd(bsz,
buf_3,
grad_output_ptr,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
} else {
if (_norm_layer2.UseMean())
_norm_layer2.Backward(bsz,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_norm_layer2.Backward(bsz,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
}
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;
_attn_out_linear.Backward(bsz_seq,
attn_output_dropout_buf,
attn_o_inp_ptr,
attn_ow_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
_cublasHandle,
_stream,
buf_1);
launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);
if (_attn_prob_dropout.HasDropout()) {
if (_attn_dropout_checkpoint)
_attn_prob_dropout.Forward(
bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);
_attn_context.Backward(bsz_heads,
buf_2,
v_tf_ptr,
(_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
_cublasHandle,
buf_3,
ff2_buf);
} else
_attn_context.Backward(
bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);
_attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);
_softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);
_attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);
launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);
if (_pre_or_postLayerNorm)
_qkv_linear.Backward(bsz_seq,
ff2_buf,
inp_norm_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
else
_qkv_linear.Backward(bsz_seq,
ff2_buf,
input_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.BackwardFusedAdd(bsz,
buf_2,
buf_0,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
input_ptr);
else
_norm_layer3.BackwardFusedAdd(bsz,
buf_2,
buf_0,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
inp_norm_ptr);
} else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
// Dropout will be skipped when not in training mode.
_attn_prob_dropout.SetTrainingMode(training);
_attn_output_dropout.SetTrainingMode(training);
_layer_output_dropout.SetTrainingMode(training);
}
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
}
template <typename T>
int create_transformer_layer(int layer_id,
int batch_size,
int hidden_dim,
int num_heads,
int intermediate_size,
int seq_length,
float attn_dropout_ratio,
float hidden_dropout_ratio,
int seed,
bool pre_or_postLayerNorm,
bool test_gemm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer;
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]."
<< std::endl;
return 0;
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b,
bool training_mode,
bool prelayernorm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint)
{
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
int bsz = input.size(0);
const T* input_ptr = (const T*)input.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* output_b_ptr = (const T*)output_b.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
auto output = torch::empty_like(input);
T* out_ptr = (T*)output.data_ptr();
auto options = torch::TensorOptions()
.dtype(input.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto uint8_options = torch::TensorOptions()
.dtype(torch::kInt8)
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(false);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
auto qkv_tf = torch::empty({(bsz * layer->GetSeqLength()), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
torch::empty({(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
uint8_options);
auto attn_output_dropout_mask =
torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
torch::Tensor ff2_inp =
torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options);
torch::Tensor gelu_inp =
(gelu_checkpoint
? ff2_inp
: torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
torch::Tensor soft_out = torch::empty(
{(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
: torch::empty(
{(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr());
layer->Forward(bsz,
input_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_qkvb_ptr,
attn_ow_ptr,
attn_ob_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
output_b_ptr,
norm_w_ptr,
norm_b_ptr,
out_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr);
return {output,
inp_norm,
qkv_tf,
soft_out,
ctx_bufB,
attn_o_inp,
add_res,
ff1_inp,
gelu_inp,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask};
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
const torch::Tensor& grad_output,
const torch::Tensor& output,
const torch::Tensor& inp_norm,
const torch::Tensor& qkv_tf,
const torch::Tensor& soft_out,
const torch::Tensor& ctx_bufB,
const torch::Tensor& attn_o_inp,
const torch::Tensor& add_res,
const torch::Tensor& ff1_inp,
const torch::Tensor& gelu_inp,
const torch::Tensor& ff2_inp,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b)
{
auto g_output = grad_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
CHECK_INPUT(inp_norm);
CHECK_INPUT(qkv_tf);
CHECK_INPUT(add_res);
CHECK_INPUT(soft_out);
CHECK_INPUT(ctx_bufB);
CHECK_INPUT(attn_o_inp);
CHECK_INPUT(ff1_inp);
CHECK_INPUT(gelu_inp);
CHECK_INPUT(ff2_inp);
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
int bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
auto grad_attn_ow = torch::empty_like(attn_ow);
auto grad_attn_ob = torch::empty_like(attn_ob);
auto grad_attn_nw = torch::empty_like(attn_nw);
auto grad_attn_nb = torch::empty_like(attn_nb);
auto grad_inter_w = torch::empty_like(inter_w);
auto grad_inter_b = torch::empty_like(inter_b);
auto grad_output_w = torch::empty_like(output_w);
auto grad_output_b = torch::empty_like(output_b);
auto grad_norm_w = torch::empty_like(norm_w);
auto grad_norm_b = torch::empty_like(norm_b);
// inputs.
const T* grad_output_ptr = (const T*)g_output.data_ptr();
const T* input_ptr = (const T*)input.data_ptr();
const T* output_ptr = (const T*)output.data_ptr();
const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
const T* add_res_ptr = (const T*)add_res.data_ptr();
const T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)k_tf.data_ptr();
const T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)v_tf.data_ptr();
const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
const T* soft_out_ptr = (const T*)soft_out.data_ptr();
const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
// outputs.
T* grad_input_ptr = (T*)grad_input.data_ptr();
T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
input_ptr,
output_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_ow_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
norm_w_ptr,
norm_b_ptr,
grad_input_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr);
return {grad_input,
grad_attn_qkvw,
grad_attn_qkvb,
grad_attn_ow,
grad_attn_ob,
grad_attn_nw,
grad_attn_nb,
grad_inter_w,
grad_inter_b,
grad_output_w,
grad_output_b,
grad_norm_w,
grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward_fp32",
&ds_transformer_forward<float>,
"DeepSpeed Transformer forward with fp32 (CUDA)");
m.def("forward_fp16",
&ds_transformer_forward<__half>,
"DeepSpeed Transformer forward with fp16 (CUDA)");
m.def("backward_fp32",
&ds_transformer_backward<float>,
"DeepSpeed Transformer backward with fp32 (CUDA)");
m.def("backward_fp16",
&ds_transformer_backward<__half>,
"DeepSpeed Transformer backward with fp16 (CUDA)");
m.def("create_transformer_layer_fp32",
&create_transformer_layer<float>,
"Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
m.def("create_transformer_layer_fp16",
&create_transformer_layer<__half>,
"Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
#include "custom_cuda_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements per thread on each iteration, for `iterations`
iterations. It was written with the intention of launching 256-thread
threadblocks, so for BERT-large we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations so that each iteration covers a block of 1024 elements.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerical overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
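/*
Worked example of the heuristic above (illustrative only; an intermediate size of
4096 floats, as in BERT-large, is assumed here):
  iterations = (4096 + 1023) / 1024 = 4
  threads    = 4096 / iterations / 4 = 256   // 4 elements per thread per iteration
  grid       = batch_size * sequence_length  // one threadblock per row
Each thread therefore touches iterations * 4 = 16 contiguous elements of its row,
loaded as four float4 in FP32 or as four float2 (each holding 2x __half2) in FP16.
*/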
__global__ void gelu_kernel(const float* input, float* vals, int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}
template void
launch_bias_gelu<float>(const float*, const float*, float*, int, int, int, cudaStream_t);
template void
launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, int, cudaStream_t);
template void
launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t);
#include "general_kernels.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
cudaStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
__global__ void fused_add2_kernel(float* out,
const float* inp1,
const float* inp2,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add2_kernel(__half* out,
const __half* inp1,
const __half* inp2,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
inp1_4 = inp1_arr[row * row_stride + id];
inp2_4 = inp2_arr[row * row_stride + id];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
#include "custom_cuda_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
Unlike the GELU, which doesn't require template parameters, this layer does, since it
relies fairly heavily on unrolling loops. Currently, I exclude bounds checks and
assume that the number of elements is a multiple of a power of 2. Default behavior
for our purposes uses 256 threads for floats, and 128 threads for __half. This restriction
is a result of using the shift parameter to perform the minimum number of register-file
shuffles necessary, which requires the number of threads in the secondary reduction to
be 1, 2, 4, 8, 16, or 32. The number of threads here corresponds to the number of complete
warps in the threadblock.
For FP16, the row statistics are accumulated in FP32, but the normalization arithmetic
itself stays in __half2 to utilize its 2x throughput and avoid conversion overhead.
For specific launch constraints, see the launch functions.
*/
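/*
Sketch of what the kernels below compute per row (standard layer norm, written here
only as a reading aid; H denotes the hidden size):
  mean = (1/H) * sum_j x_j
  var  = (1/H) * sum_j (x_j - mean)^2      // epsilon is added to var before it is used or stored
  y_j  = gamma_j * (x_j - mean) * rsqrt(var + epsilon) + beta_j
Example launch shape, assuming hidden_dim = 1024:
  float : template args <1024, 4> with 256 threads -> iteration_stride = 256
  __half: template args <512, 4>  with 128 threads -> row_stride counts __half2 pairs
*/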
template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
float* vars = nullptr,
float* means = nullptr,
float* vals_hat = nullptr)
{
constexpr int iteration_stride = row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[iterations];
__shared__ float shr[iteration_stride >> 5];
float sum = 0.f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[row * row_stride + i * iteration_stride + id];
sum += vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (g.thread_rank() == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
variance += (vals_arr[i] - mean) * (vals_arr[i] - mean);
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = (vals_arr[i] - mean) * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[row * row_stride + i * iteration_stride + id] = vals_arr[i];
}
}
template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
__half* vars = nullptr,
__half* means = nullptr,
__half* vals_hat = nullptr)
{
#if __CUDA_ARCH__ >= 700
constexpr int iteration_stride = row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
__half2 vals_arr[iterations];
float2 vals_f[iterations];
__shared__ float shr[iteration_stride >> 5];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
float sum = 0.f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[row * row_stride + i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
variance += (vals_f[i].x - mean) * (vals_f[i].x - mean);
variance += (vals_f[i].y - mean) * (vals_f[i].y - mean);
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 mean_h = __float2half2_rn(mean);
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
for (int i = 0; i < iterations; i++) {
vals_arr[i] = __float22half2_rn(vals_f[i]);
vals_arr[i] = (vals_arr[i] - mean_h) * h2rsqrt(variance_h);
vals_arr[i] = vals_arr[i] * gamma_cast[i * iteration_stride + id] +
beta_cast[i * iteration_stride + id];
vals_cast[row * row_stride + i * iteration_stride + id] = vals_arr[i];
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means,
T* vals_hat);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means,
float* vals_hat)
{
constexpr int threads = THREADS;
dim3 grid_dim(batch_size * sequence_length);
dim3 block_dim(threads);
// There are some limitations on calling the functions below; for now, just enumerate the supported hidden sizes.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else
throw std::runtime_error("Unsupported hidden_dim.");
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
__half* vals_hat)
{
constexpr int threads = 128;
dim3 grid_dim(batch_size * sequence_length);
dim3 block_dim(threads);
// There are some limitations on calling the functions below; for now, just enumerate the supported hidden sizes.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
else
throw std::runtime_error("Unsupported hidden_dim.");
}
template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
float* vars = nullptr,
float* vals_hat = nullptr,
bool save_vals = false)
{
constexpr int iteration_stride = row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[iterations];
__shared__ float shr[iteration_stride >> 5];
float sum = 0.f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[row * row_stride + i * iteration_stride + id];
sum += vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
variance += (vals_arr[i] - mean) * (vals_arr[i] - mean);
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = (vals_arr[i] - mean) * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[row * row_stride + i * iteration_stride + id] = vals_arr[i];
}
}
template <int row_stride, int iterations>
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
__half* vars = nullptr,
__half* vals_hat = nullptr,
bool save_vals = false)
{
#if __CUDA_ARCH__ >= 700
constexpr int iteration_stride = row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
__half2 vals_arr[iterations];
float2 vals_f[iterations];
__shared__ float shr[iteration_stride >> 5];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
float sum = 0.f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[row * row_stride + i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
variance += (vals_f[i].x - mean) * (vals_f[i].x - mean);
variance += (vals_f[i].y - mean) * (vals_f[i].y - mean);
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 mean_h = __float2half2_rn(mean);
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);
for (int i = 0; i < iterations; i++) {
vals_arr[i] = __float22half2_rn(vals_f[i]);
vals_arr[i] = (vals_arr[i] - mean_h) * h2rsqrt(variance_h);
vals_arr[i] = vals_arr[i] * gamma_cast[i * iteration_stride + id] +
beta_cast[i * iteration_stride + id];
vals_cast[row * row_stride + i * iteration_stride + id] = vals_arr[i];
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* vals_hat,
bool save_vals);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
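/*
Illustrative check of the restrictions above, assuming hidden_dim = 1024:
  float : row_stride = 1024, threads = THREADS (256 in our defaults), iterations = 4 -> 256 * 4 == 1024
  __half: row_stride = 1024 / 2 = 512, threads = 128, iterations = 4 -> 128 * 4 == 512
A hypothetical host-side call matching the launchers below might look like:
  // launch_bias_residual_layer_norm<float>(vals, residual, gamma, beta, 1e-12f,
  //                                        batch, seq_len, 1024, stream,
  //                                        preLayerNorm, training, vars, vals_hat,
  //                                        save_vals);
(epsilon, batch, seq_len, and stream are placeholders for illustration, not values taken from this file.)
*/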
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* vals_hat,
bool save_vals)
{
constexpr int threads = THREADS;
dim3 grid_dim(batch_size * sequence_length);
dim3 block_dim(threads);
// There are some limitations on calling the functions below; for now, just enumerate the supported hidden sizes.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else
throw std::runtime_error("Unsupported hidden_dim.");
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* vals_hat,
bool save_vals)
{
constexpr int threads = 128;
dim3 grid_dim(batch_size * sequence_length);
dim3 block_dim(threads);
// There are some limitations on calling the functions below; for now, just enumerate the supported hidden sizes.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else
throw std::runtime_error("Unsupported hidden_dim.");
}
/* Normalize Gamma & Betta gradients
 * Compute gradients using either X_hat directly or, in the
 * invertible case, the normalized output (from which X_hat is recovered).
 * Combines the transpose with the gradient computation.
 */
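/*
In equations (reference only), for a column c of the hidden dimension:
  betta_grad[c] = sum_r out_grad[r][c]
  gamma_grad[c] = sum_r x_hat[r][c] * out_grad[r][c]
where, when `invertible` is set, x_hat is recovered from the stored output as
  x_hat = (vals_hat - betta[c]) / gamma[c]
and otherwise vals_hat already holds x_hat. The kernel stages the per-thread partial
sums in shared-memory tiles and reduces them with warp shuffles before writing.
*/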
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
float betta_reg = (invertible ? (float)betta[pos] : 0.0f);
float gamma_reg = (float)gamma[pos];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
 * Compute gradients using the input to the normalization,
 * reconstructing X_hat from the saved means and variances.
 * Combines the transpose with the gradient computation.
 */
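/*
Same reduction as above, except x_hat is reconstructed on the fly from the
saved statistics instead of the normalized output:
  x_hat[r][c] = (X_data[r][c] - means[r]) * rsqrt(vars[r])
(as in the forward kernel above, vars[r] already includes the epsilon term).
*/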
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances saved from the forward pass.
 * This type of backward is invertible!
 * We do the backward using X_hat = (X - u) / sqrt(variance), or the output of the normalization.
*/
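/*
Reference form of the input gradient computed below (dy = out_grad, H = hidden size,
x_hat as defined above; vars[row] already contains the epsilon from the forward pass):
  dx = rsqrt(var) * ( gamma .* dy
                      - mean_j(gamma_j * dy_j)
                      - x_hat .* mean_j(gamma_j * dy_j * x_hat_j) )
The kernel realizes this with two row-wise reductions: it first forms
rsqrt(var) * (gamma .* dy - x_hat .* mean_j(gamma_j * dy_j * x_hat_j)) and then
subtracts that vector's own row mean, which matches the expression above whenever
mean_j(x_hat_j) == 0.
*/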
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible)
{
constexpr int iterations = row_stride / THREADS;
constexpr int iteration_stride = THREADS; // row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
float vals_arr[iterations];
float vals_hat_arr[iterations];
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = (invertible ? (vals_hat[row * row_stride + i * iteration_stride + id] -
betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[row * row_stride + i * iteration_stride + id]);
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++)
inp_grad[row * row_stride + i * iteration_stride + id] = (vals_arr[i] - sum);
}
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible)
{
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations;
constexpr int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num =
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
__half2 vals_arr[iterations];
float2 vals_arr_f[iterations];
__half2 vals_hat_arr[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = (invertible ? (vals_hat_h[row * row_stride + i * iteration_stride + id] -
betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[row * row_stride + i * iteration_stride + id]);
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const float* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 512)
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1024)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const __half* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768)
LayerNormBackward2<384><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 512)
LayerNormBackward2<256><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1024)
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
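/* Note on the __half launcher above: the half kernels operate on __half2, so each thread covers
 * two elements per load. That is why the template argument is hidden_dim / 2 and the block is
 * launched with threads / 2, while the float path uses hidden_dim and threads directly.
 */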
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
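/* Sketch of the math the kernels below implement (the standard LayerNorm input gradient):
 *   x_hat = (x - mu) * rsqrt(var)
 *   dy    = out_grad * gamma
 *   dx    = rsqrt(var) * (dy - mean(dy) - x_hat * mean(dy * x_hat))
 * where mean() reduces over the hidden dimension. The two block-wide reductions compute
 * mean(dy * (x - mu)) and the final mean term, respectively.
 */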
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad)
{
constexpr int iterations = row_stride / THREADS;
constexpr int iteration_stride = THREADS; // row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
float vals_arr[iterations];
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[iterations];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[row * row_stride + i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++)
inp_grad[row * row_stride + i * iteration_stride + id] = (vals_arr[i] - sum);
}
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad)
{
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations;
constexpr int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num =
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
__half2 vals_arr[iterations];
float2 vals_arr_f[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[iterations];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_h[row * row_stride + i * iteration_stride + id] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 512)
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1024)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768)
LayerNormBackward2<384><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 512)
LayerNormBackward2<256><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1024)
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
float betta_reg = (invertible ? (float)betta[pos] : 0.0f);
float gamma_reg = (float)gamma[pos];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
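/* The _fused_add variants fold the gradient of a residual branch into the layer-norm backward:
 * LayerNormBackward1_fused_add accumulates (out_grad1 + out_grad2) for the gamma/betta
 * gradients, and the LayerNormBackward2_fused_add kernels below add out_grad2 to the computed
 * input gradient before storing it, saving a separate element-wise addition kernel.
 */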
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible)
{
constexpr int iterations = row_stride / THREADS;
constexpr int iteration_stride = THREADS; // row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
float vals_arr[iterations];
float vals_hat_arr[iterations];
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = (invertible ? (vals_hat[row * row_stride + i * iteration_stride + id] -
betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[row * row_stride + i * iteration_stride + id]);
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++)
inp_grad[row * row_stride + i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[row * row_stride + i * iteration_stride + id];
}
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible)
{
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations;
constexpr int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num =
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
__half2 vals_arr[iterations];
float2 vals_arr_f[iterations];
__half2 vals_hat_arr[iterations];
// float2 result[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = (invertible ? (vals_hat_h[row * row_stride + i * iteration_stride + id] -
betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[row * row_stride + i * iteration_stride + id]);
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] =
temp + out_grad_h2[row * row_stride + i * iteration_stride + id];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const float* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 512)
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1024)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const __half* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768)
LayerNormBackward2_fused_add<384><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 512)
LayerNormBackward2_fused_add<256><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1024)
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad)
{
constexpr int iterations = row_stride / THREADS;
constexpr int iteration_stride = THREADS; // row_stride / iterations;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
float vals_arr[iterations];
float vals_hat_arr[iterations];
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[row * row_stride + i * iteration_stride + id];
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[iterations];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++)
inp_grad[row * row_stride + i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[row * row_stride + i * iteration_stride + id];
}
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad)
{
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations;
constexpr int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
constexpr int warp_num =
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num];
__half2 vals_arr[iterations];
float2 vals_arr_f[iterations];
__half2 vals_hat_arr[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[row * row_stride + i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[row * row_stride + i * iteration_stride + id];
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[iterations];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] =
temp + out_grad_h2[row * row_stride + i * iteration_stride + id];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 512)
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1024)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768)
LayerNormBackward2_fused_add<384><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 512)
LayerNormBackward2_fused_add<256><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1024)
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
#include "custom_cuda_layers.h"
#include "general_kernels.h"
namespace cg = cooperative_groups;
// Fused attention + softmax
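/* Sketch of what the kernels below compute, per attention row (masked, numerically stable
 * softmax):
 *   s_j = score_j + mask_j
 *   p_j = exp(s_j - max_k s_k) / (sum_k exp(s_k - max_k s_k) + 1e-6)
 * The max and the sum are block-wide reductions (warp shuffles plus shared memory), and the
 * 1e-6 term corresponds to the `sum += 1e-6` guard against division by zero.
 */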
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = std::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#if __CUDA_ARCH__ >= 700
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = std::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t, bool);
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
cudaStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
attn_softmax<2, (threads / 2), 2>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
attn_softmax<4, (threads / 4), 4>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
attn_softmax<8, (threads / 8), 8>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
attn_softmax<16, (threads / 16), 16>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
attn_softmax<32, (threads / 32), 32>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
attn_softmax<32, (threads / 64), 64>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
cudaStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
attn_softmax<2, (threads / 2), 2>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
attn_softmax<4, (threads / 4), 4>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
attn_softmax<8, (threads / 8), 8>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
attn_softmax<16, (threads / 16), 16>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
attn_softmax<32, (threads / 32), 32>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
attn_softmax<32, (threads / 64), 64>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
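/* Illustrative host-side usage of the launchers above; the variable names and shapes here are
 * assumptions for the sketch, not part of this file:
 *
 *   float* scores;      // [batch, heads, seq, seq] attention scores on the device
 *   const float* mask;  // [batch, seq] additive mask (0 to keep, large negative to drop)
 *   launch_attn_softmax<float>(scores, mask, batch, heads, seq, stream);
 *
 * The softmax is applied in place over the last dimension of `scores`.
 */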
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> 5;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
            grad_reg += ((float)val_reg[i] *
                         (float)soft_reg[i]); // doing this multiplication in half precision can
                                              // lose ~2% accuracy, so accumulate in float
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
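/* Both backward kernels implement the standard softmax gradient: with y = softmax(x) and
 * incoming gradient dy,
 *   dx_i = y_i * (dy_i - sum_j dy_j * y_j)
 * The dot product sum_j dy_j * y_j is the reduction accumulated in grad_reg / sum.
 */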
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
cudaStream_t stream)
{
if ((seq_length % WARP_SIZE) != 0 || seq_length > 2048)
throw std::runtime_error("Invalid sequence length found in softmax backward.");
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
switch (seq_length) {
case 32:
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 64:
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 128:
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 256:
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 384:
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 512:
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 768:
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 1024:
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 2048:
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
default:
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
}
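/* Note: each case maps seq_length to ITERATIONS = seq_length / WARP_SIZE, i.e. the number of
 * elements handled by each of the 32 lanes of a warp (for example, 384 / 32 = 12); this is why
 * only multiples of WARP_SIZE up to 2048 are accepted.
 */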
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
cudaStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
cudaStream_t stream);
#include "custom_cuda_layers.h"
#define rows_trans 16
#define cols_trans 16
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
__shared__ T data_block[rows_trans * (cols_trans + 1)];
int r = threadIdx.x / cols_trans;
int c = threadIdx.x % cols_trans;
int m = row_width / cols_trans;
int i = blockIdx.x / m * rows_trans + r;
int j = blockIdx.x % m * cols_trans + c;
int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);
for (int k = 0; k < rows_trans; k += row_stride)
data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];
__syncthreads();
i = blockIdx.x % m * rows_trans + r;
j = blockIdx.x / m * cols_trans + c;
for (int k = 0; k < rows_trans; k += row_stride)
out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
__half* out_mat,
int rows,
int cols,
cudaStream_t stream)
{
int threads = THREADS;
Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
inp_mat, out_mat, cols, rows);
}
template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream)
{
int threads = THREADS;
Transpose_Kernel<float><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
inp_mat, out_mat, cols, rows);
}
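/* transform_0213 permutes the attention activations from [batch, seq, heads, head_dim] to
 * [batch, heads, seq, head_dim] (dimension order 0, 2, 1, 3). Loads and stores are vectorized
 * as float4, i.e. 4 floats or 8 halves per thread.
 */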
template <typename T>
__global__ void transform_0213(T* output, const T* vals, int hidden_dim, int seq_length, int heads);
template <>
__global__ void transform_0213<float>(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int heads)
{
int d0_stride = hidden_dim * seq_length / 4;
int d1_stride = hidden_dim / 4;
int d2_stride = hidden_dim / heads / 4;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int heads)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length / 8;
int d1_stride = hidden_dim / 8;
int d2_stride = hidden_dim / heads / 8;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int d2 = threadIdx.y; // Head (0-11)
    int d3 = threadIdx.x;  // Values (groups of 8)
float4 vals_arr[1];
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
template <>
void launch_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
cudaStream_t stream)
{
dim3 block_dim(hidden_dim / heads / 4, heads);
dim3 grid_dim(batch_size, seq_length);
transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads);
}
template <>
void launch_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
cudaStream_t stream)
{
dim3 block_dim(hidden_dim / heads / 8, heads);
dim3 grid_dim(batch_size, seq_length);
transform_0213<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads);
}
// Bias add
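/* bias_add_transform_0213 fuses the QKV bias addition with the same 0213 permutation: the
 * fused [batch, seq, 3 * hidden] activations are split into three matrices (indexed by cnt)
 * and each is written out as [batch, heads, seq, head_dim], as summarized by the
 * "[B S C*H] -> C * [B A S N]" comment further below.
 */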
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads)
{
int d0_stride = hidden_dim * seq_length / 4;
int d1_stride = hidden_dim / 4;
int d2_stride = hidden_dim / heads / 4;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * gridDim.z + cnt * d1_stride +
d1 * d1_stride * gridDim.z + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length / 8;
int d1_stride = hidden_dim / 8;
int d2_stride = hidden_dim / heads / 8;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 8 halves per float4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
#endif
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
cudaStream_t stream,
int trans_count)
{
dim3 block_dim(hidden_dim / heads / 4, heads);
dim3 grid_dim(batch_size, seq_length, trans_count);
bias_add_transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads);
}
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
cudaStream_t stream,
int trans_count)
{
dim3 block_dim(hidden_dim / heads / 8, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
bias_add_transform_0213<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads);
}
template <typename T>
__global__ void transform4d_0213(T* out, const T* in, int heads, int seq_length, int hidden_dim);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim)
{
int d0_stride = hidden_dim * seq_length / 4;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads / 4;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim / 4;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length + blockDim.y - 1) / blockDim.y); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 4 floats per float4)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length / 8;
int d1_stride = hidden_dim / 8;
int d2_stride = hidden_dim / heads / 8;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + d1 % 2 * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
#endif
}
// 3 * [B A S N] - > [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
cudaStream_t stream,
int trans_count)
{
dim3 grid_dims(batch_size, heads * ((seq_length + 7) / 8), trans_count);
dim3 block_dims(hidden_dim / heads / 4, 8);
transform4d_0213<float>
<<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
cudaStream_t stream,
int trans_count)
{
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads / 8, heads, trans_count);
transform4d_0213<__half>
<<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim);
}
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
......@@ -5,6 +5,8 @@ Copyright 2020 The Microsoft DeepSpeed Team
from deepspeed.pt.deepspeed_light import DeepSpeedLight
from deepspeed.pt.deepspeed_light import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from deepspeed.pt.deepspeed_lr_schedules import add_tuning_arguments
from deepspeed.pt.deepspeed_cuda import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
import deepspeed.pt.deepspeed_checkpointing as checkpointing
......
from torch import nn
from torch.autograd import Function
import torch
import json
import math
import deepspeed_transformer_cuda as ds_transformer_cuda
import deepspeed_stochastic_transformer_cuda as ds_stochastic_transformer_cuda
class TransformerConfig():
def __init__(self,
batch_size,
max_seq_length,
hidden_size,
heads,
attn_dropout_ratio,
hidden_dropout_ratio,
num_hidden_layers,
initializer_range):
self.layer_id = -1
self.batch_size = batch_size
self.hidden_size = hidden_size
self.max_seq_length = max_seq_length
self.heads = heads
self.attn_dropout_ratio = attn_dropout_ratio
self.hidden_dropout_ratio = hidden_dropout_ratio
self.num_hidden_layers = num_hidden_layers
self.initializer_range = initializer_range
class DeepSpeedTransformerConfig(TransformerConfig):
"""Initialize the DeepSpeed Transformer Config.
Arguments:
batch_size: The maximum batch size used for running the kernel on each GPU
max_seq_length: The sequence-length of the model being trained with DeepSpeed
hidden_size: The hidden size of the transformer layer
heads: The number of heads in the self-attention of the transformer layer
attn_dropout_ratio: The ratio of dropout for the attention's output
hidden_dropout_ratio: The ratio of dropout for the transformer's output
num_hidden_layers: The number of transformer layers
initializer_range: BERT model's initializer range for initializing parameter data
local_rank: Optional: The rank of the GPU running the transformer kernel. It is not required
if the model has already set the current device; otherwise it needs to be set
so that the transformer kernel can work on the right device
seed: The random seed for the dropout layers
fp16: Enable half-precision computation
pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture
normalize_invertible: Optional: Enable invertible LayerNorm execution (dropping the input activation),
default is False
gelu_checkpoint: Optional: Enable checkpointing of Gelu activation output to save memory,
default is False
adjust_init_range: Optional: Set as True (default) if the model adjusts the initial weight values of
its self-attention output and layer output; False keeps the initializer_range unchanged.
See the adjustment below:
output_std = self.config.initializer_range / math.sqrt(2.0 * num_layers)
attn_dropout_checkpoint: Optional: Enable checkpointing of attention dropout to save memory,
default is False
stochastic_mode: Enable for high performance. Note that this flag has some level of
non-determinism and can produce different results on different runs. However, we have seen
that by enabling it, pretraining tasks such as BERT are not affected and can reach
a high accuracy level. For downstream tasks, such as fine-tuning, we recommend
turning it off in order to be able to reproduce the same result through the regular kernel execution.
"""
def __init__(self,
batch_size=-1,
max_seq_length=-1,
hidden_size=-1,
heads=-1,
attn_dropout_ratio=-1,
hidden_dropout_ratio=-1,
num_hidden_layers=-1,
initializer_range=-1,
local_rank=-1,
seed=-1,
fp16=False,
pre_layer_norm=True,
normalize_invertible=False,
gelu_checkpoint=False,
adjust_init_range=True,
attn_dropout_checkpoint=False,
stochastic_mode=False):
super(DeepSpeedTransformerConfig,
self).__init__(batch_size,
max_seq_length,
hidden_size,
heads,
attn_dropout_ratio,
hidden_dropout_ratio,
num_hidden_layers,
initializer_range)
self.fp16 = fp16
self.pre_layer_norm = pre_layer_norm
self.local_rank = local_rank
self.seed = seed
self.normalize_invertible = normalize_invertible
self.gelu_checkpoint = gelu_checkpoint # True: if higher batch size is required
self.adjust_init_range = adjust_init_range
self.test_gemm = False
self.training = True
self.is_grad_enabled = True
self.attn_dropout_checkpoint = attn_dropout_checkpoint
self.stochastic_mode = stochastic_mode
@classmethod
def from_dict(cls, json_object):
config = DeepSpeedTransformerConfig()
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
class DeepSpeedTransformerFunction(Function):
@staticmethod
def forward(ctx,
input,
input_mask,
self,
grads,
layer_id,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b,
config):
bsz = input.shape[0]
if bsz > config.batch_size:
raise ValueError('Input batch size exceeds the limit.')
cuda_module = ds_stochastic_transformer_cuda if config.stochastic_mode else ds_transformer_cuda
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32
(output,
inp_norm,
qkv_tf,
soft_inp,
ctx_bufB,
attn_o_inp,
add_res,
ff1_inp,
gelu_inp,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask) = forward_func(config.layer_id,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b,
config.training,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
config.gelu_checkpoint)
# For testing only.
if grads is not None:
for i in [2]:
attn_qkvw.register_hook(
lambda x,
i=i,
self=self: grads.append([
x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
("Q_W" if i == 0 else "K_W" if i == 1 else "V_W")
]))
for i in [2]:
attn_qkvb.register_hook(
lambda x,
i=i,
self=self: grads.append([
x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
("Q_B" if i == 0 else "K_B" if i == 1 else "V_B")
]))
attn_ow.register_hook(lambda x, self=self: grads.append([x, "O_W"]))
attn_ob.register_hook(lambda x, self=self: grads.append([x, "O_B"]))
attn_nw.register_hook(lambda x, self=self: grads.append([x, "N2_W"]))
attn_nb.register_hook(lambda x, self=self: grads.append([x, "N2_B"]))
inter_w.register_hook(lambda x, self=self: grads.append([x, "int_W"]))
inter_b.register_hook(lambda x, self=self: grads.append([x, "int_B"]))
output_w.register_hook(lambda x, self=self: grads.append([x, "out_W"]))
output_b.register_hook(lambda x, self=self: grads.append([x, "out_B"]))
norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"]))
norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"]))
if config.is_grad_enabled:
if (config.pre_layer_norm and config.normalize_invertible):
ctx.save_for_backward(input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b)
else:
ctx.save_for_backward(output,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b)
ctx.config = config
if (config.pre_layer_norm or not config.normalize_invertible):
ctx.inp_norm = inp_norm
ctx.qkv_tf = qkv_tf
ctx.soft_inp = soft_inp
if not config.attn_dropout_checkpoint:
ctx.ctx_bufB = ctx_bufB
ctx.attn_o_inp = attn_o_inp
if not config.normalize_invertible:
ctx.add_res = add_res
ctx.ff1_inp = ff1_inp
if not config.gelu_checkpoint:
ctx.gelu_inp = gelu_inp
ctx.ff2_inp = ff2_inp
ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
ctx.attn_output_dropout_mask = attn_output_dropout_mask
ctx.layer_output_dropout_mask = layer_output_dropout_mask
return output
@staticmethod
def backward(ctx, grad_output):
bsz = grad_output.shape[0]
if bsz > ctx.config.batch_size:
raise ValueError('grad_output batch size exceeds the limit.')
assert ctx.config.training
if (ctx.config.pre_layer_norm and ctx.config.normalize_invertible):
(input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b) = ctx.saved_tensors
else:
(output,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b) = ctx.saved_tensors
cuda_module = ds_stochastic_transformer_cuda if ctx.config.stochastic_mode else ds_transformer_cuda
backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32
(grad_input,
grad_attn_qkvw,
grad_attn_qkvb,
grad_attn_ow,
grad_attn_ob,
grad_attn_nw,
grad_attn_nb,
grad_inter_w,
grad_inter_b,
grad_output_w,
grad_output_b,
grad_norm_w,
grad_norm_b) = backward_func(
ctx.config.layer_id,
grad_output,
(ctx.inp_norm if (ctx.config.pre_layer_norm
and ctx.config.normalize_invertible) else output),
(ctx.inp_norm if (ctx.config.pre_layer_norm
or not ctx.config.normalize_invertible) else input),
ctx.qkv_tf,
ctx.soft_inp,
(ctx.soft_inp if ctx.config.attn_dropout_checkpoint else ctx.ctx_bufB),
ctx.attn_o_inp,
(ctx.ff1_inp if ctx.config.normalize_invertible else ctx.add_res),
ctx.ff1_inp,
(ctx.ff2_inp if ctx.config.gelu_checkpoint else ctx.gelu_inp),
ctx.ff2_inp,
ctx.attn_prob_dropout_mask,
ctx.attn_output_dropout_mask,
ctx.layer_output_dropout_mask,
(ctx.inp_norm if (ctx.config.pre_layer_norm
and ctx.config.normalize_invertible) else input),
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b)
return (grad_input,
None,
None,
None,
None,
grad_attn_qkvw,
grad_attn_qkvb,
grad_attn_ow,
grad_attn_ob,
grad_attn_nw,
grad_attn_nb,
grad_inter_w,
grad_inter_b,
grad_output_w,
grad_output_b,
grad_norm_w,
grad_norm_b,
None)
class DeepSpeedTransformerLayer(nn.Module):
"""Initialize the DeepSpeed Transformer Layer.
Arguments:
layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers,
layer_id will be 0,1,2...23 when each layer object is instantiated
config: An object of DeepSpeedTransformerConfig
initial_weights: Optional: Only used for unit test
initial_biases: Optional: Only used for unit test
"""
def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
super(DeepSpeedTransformerLayer, self).__init__()
self.config = config
self.config.layer_id = layer_id
print("DeepSpeed Transformer config is ", self.config.__dict__)
if self.config.local_rank >= 0:
torch.cuda.set_device(self.config.local_rank)
if initial_weights is None and initial_biases is None:
self.attn_qkvw = nn.Parameter(
torch.Tensor(self.config.hidden_size * 3,
self.config.hidden_size))
self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
self.attn_ow = nn.Parameter(
torch.Tensor(self.config.hidden_size,
self.config.hidden_size))
self.attn_ob = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.inter_w = nn.Parameter(
torch.Tensor(4 * self.config.hidden_size,
self.config.hidden_size))
self.inter_b = nn.Parameter(torch.Tensor(4 * self.config.hidden_size))
self.output_w = nn.Parameter(
torch.Tensor(self.config.hidden_size,
4 * self.config.hidden_size))
self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.init_transformer_weights(self.config.adjust_init_range)
else:
# For testing only.
self.attn_qkvw = nn.Parameter(
torch.Tensor(self.config.hidden_size * 3,
self.config.hidden_size))
for i in range(3):
self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \
torch.empty_like(initial_weights[i]).copy_(initial_weights[i])
self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
self.attn_qkvb.data.zero_()
self.attn_ow = initial_weights[3]
self.attn_ob = initial_biases[3]
self.attn_nw = initial_weights[4]
self.attn_nb = initial_biases[4]
self.inter_w = initial_weights[5]
self.inter_b = initial_biases[5]
self.output_w = initial_weights[6]
self.output_b = initial_biases[6]
self.norm_w = initial_weights[7]
self.norm_b = initial_biases[7]
# create the layer in cuda kernels.
cuda_module = ds_stochastic_transformer_cuda if self.config.stochastic_mode else ds_transformer_cuda
create_layer_func = cuda_module.create_transformer_layer_fp16 if self.config.fp16 else cuda_module.create_transformer_layer_fp32
create_layer_func(self.config.layer_id,
self.config.batch_size,
self.config.hidden_size,
self.config.heads,
4 * self.config.hidden_size,
self.config.max_seq_length,
self.config.attn_dropout_ratio,
self.config.hidden_dropout_ratio,
self.config.seed,
self.config.pre_layer_norm,
self.config.test_gemm,
self.config.attn_dropout_checkpoint,
self.config.normalize_invertible,
self.config.gelu_checkpoint,
self.config.stochastic_mode)
def init_transformer_weights(self, adjust_init_range=False):
num_layers = self.config.num_hidden_layers
output_std = self.config.initializer_range
if adjust_init_range and self.config.local_rank == 0:
print("Accounting for accumulation on the residual path")
output_std = self.config.initializer_range / math.sqrt(2.0 * num_layers)
self.attn_qkvw.data.normal_(mean=0.0, std=self.config.initializer_range)
self.attn_qkvb.data.zero_()
self.attn_ow.data.normal_(mean=0.0, std=output_std)
self.attn_ob.data.zero_()
self.attn_nw.data.fill_(1.0)
self.attn_nb.data.zero_()
self.inter_w.data.normal_(mean=0.0, std=self.config.initializer_range)
self.inter_b.data.zero_()
self.output_w.data.normal_(mean=0.0, std=output_std)
self.output_b.data.zero_()
self.norm_w.data.fill_(1.0)
self.norm_b.data.zero_()
def forward(self, input, input_mask, grads=None):
self.config.training = self.training
self.config.is_grad_enabled = torch.is_grad_enabled()
return DeepSpeedTransformerFunction.apply(input,
input_mask,
self,
grads,
self.config.layer_id,
self.attn_qkvw,
self.attn_qkvb,
self.attn_ow,
self.attn_ob,
self.attn_nw,
self.attn_nb,
self.inter_w,
self.inter_b,
self.output_w,
self.output_b,
self.norm_w,
self.norm_b,
self.config)
......@@ -58,7 +58,7 @@ class FusedLamb(torch.optim.Optimizer):
min_coeff=0.01,
amsgrad=False):
global fused_lamb_cuda
fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
fused_lamb_cuda = importlib.import_module("deepspeed_lamb_cuda")
if amsgrad:
raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
......
......@@ -283,7 +283,9 @@ def main(args=None):
"-u",
"-m",
"deepspeed.pt.deepspeed_launch",
"--world_info={}".format(world_info_base64)
"--world_info={}".format(world_info_base64),
"--master_addr={}".format(args.master_addr),
"--master_port={}".format(args.master_port)
]
cmd = deepspeed_launch + [args.user_script] + args.user_args
else:
......
......@@ -17,6 +17,7 @@ DeepSpeed achieves the fastest BERT training record: 44 minutes on 1,024
NVIDIA V100 GPUs**, compared with the best published result of 67 minutes on
the same number and generation of GPUs.
For a technical overview, see our [blog post](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/).
**Code and tutorials are coming soon!**
* Brief overview, see our [press release](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/).
* Detailed technology deep dive, see our [blog post](https://www.deepspeed.ai/news/2020/05/28/bert-record.html).
* Tutorial on how to reproduce our results, see our [BERT pre-training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/).
* The source code for our transformer kernels can be found in the [DeepSpeed repo](https://github.com/microsoft/deepspeed) and BERT pre-training code can be found in the [DeepSpeedExamples repo](https://github.com/microsoft/deepspeedexamples).
---
layout: single
title: "Microsoft DeepSpeed achieves the fastest BERT training time"
excerpt: ""
categories: news
new_post: true
date: 2020-05-28 00:00:00
---
Good news! **DeepSpeed obtains the fastest BERT training record: 44 minutes on
1024 NVIDIA V100 GPU.** This is a 30% improvement over the best published result
of 67 mins in end-to-end training time to achieve the same accuracy on the same
number and generation of GPUs. This improvement does not come at the cost of
excessive hardware resources but comes from improved software efficiency. For
example, DeepSpeed can attain a staggering 64 teraflops of single GPU
performance on a NVIDIA V100 GPU which is over 50% of the hardware peak.
In this blog post, we will discuss four technological improvements that enable
DeepSpeed to achieve this record-breaking BERT training time.
1. Highly optimized transformer kernels to improve compute efficiency
2. Overlapping I/O with computation through asynchronous prefetching queue
3. Sparse output processing to eliminate wasteful computation
4. Layer-norm reordering for training stability and faster convergence
These optimizations not only benefit BERT; they are also applicable to many
other transformer-based models such as RoBERTa, XLNet, and UniLM.
## Overview of Performance Results
Compared to SOTA, DeepSpeed significantly improves single GPU performance for
transformer-based models like BERT. Figure 1 shows the single-GPU throughput of
training BERT-Large optimized through DeepSpeed, comparing with the two
well-known PyTorch implementations from [NVIDIA
BERT](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)
and [Hugging Face
BERT](https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py).
DeepSpeed reaches as high as 64 and 53 teraflops throughputs (corresponding to
272 and 52 samples/second) for sequence lengths 128 and 512, respectively,
exhibiting up to 28% throughput improvements over NVIDIA BERT and up to 62%
over HuggingFace BERT. We also support up to 1.8x larger batch size without
running out of memory.
To achieve this performance, DeepSpeed implements a stochastic transformer
which exhibits some level of non-deterministic noise without affecting overall
convergence. In addition, DeepSpeed also implements a deterministic transformer
kernel that is completely reproducible at the expense of a small performance
regression of approximately 2% on average. Users can easily choose and switch
between the two versions depending on their usage scenario: the stochastic
version pursues the ultimate training performance, while the deterministic
version may save development time by better facilitating experimentation and
debugging. We report performance numbers for both kernels in Figure 1. The
performance numbers were collected with a gradient accumulation step of 10 for
all batch sizes and configurations, since the overall batch size used in
practical scenarios typically ranges from a few hundred to a few thousand.
![Transformer-Kernel-Throughput-128](../assets/images/transformer_kernel_perf_seq128.PNG)
![Transformer-Kernel-Throughput-512](../assets/images/transformer_kernel_perf_seq512.PNG)
Figure 1: Performance evaluation of BERT-Large on a single V100 GPU, comparing
DeepSpeed with NVIDIA and HuggingFace versions of BERT in mixed-sequence length
training. The labeled points show the highest throughput of each implementation
in teraflops (Tflops). DeepSpeed boosts throughput and allows for higher batch
sizes without running out-of-memory.
Looking at distributed training across GPUs, Table 1 shows our end-to-end
BERT-Large pretraining time (F1 score of 90.5 for SQUAD) using 16 to 1024 GPUs.
We complete BERT pretraining in 44 minutes using 1024 V100 GPUs (64 NVIDIA
DGX-2 nodes). In comparison, the previous SOTA from NVIDIA takes 47 mins using
1472 V100 GPUs. DeepSpeed is not only faster but also uses 30% less resources.
Using the same 1024 GPUs, NVIDIA BERT [1] takes 67 minutes, whereas DeepSpeed
takes 44 minutes, reducing training time by 30%.
Similarly, on 256 GPUs, NVIDIA BERT takes 236 minutes while DeepSpeed takes 144
minutes (39% faster).
| Number of nodes | Number of V100 GPUs | Time |
| ----------------- | -------------------- | ------------ |
| 1 DGX-2 | 16 | 33 hr 13 min |
| 4 DGX-2 | 64 | 8 hr 41 min |
| 16 DGX-2 | 256 | 144 min |
| 64 DGX-2 | 1024 | 44 min |
Table 1: BERT-Large training time using 1 to 64 DGX-2's with DeepSpeed.
At the recent GTC 2020, NVIDIA announced the next generation hardware A100,
which now offers 2.5X hardware peak performance over the V100 GPU. Assuming
the A100 GPU allows us to obtain the same percentage of hardware peak
performance (50%) as we obtained on V100 GPUs, we expect to obtain even higher
throughput by combining our software optimizations with the new hardware. We
project it would reduce BERT training time further to less than 25 minutes on a
cluster of 1024 A100 GPUs.
## BERT Highly Optimized Transformer Kernels
GPUs have very high peak floating-point throughput, but the default Transformer
blocks in most framework implementations are far from reaching this peak.
Figure 2 shows the structure of a Transformer block with the LayerNorm placed
on the input stream of the two sublayers: Attention and Feed-Forward. To
approach the GPU peak performance, we employ two lines of optimizations in our
own Transformer kernel implementation: advanced fusion, and invertible
operators.
![Transformer-PreLN-Arch](../assets/images/transformer_preln_arch.png)
Figure 2: Transformer Layer with Pre-LayerNorm Architecture
### (a) Advanced fused kernels to reduce data movement
We observe that transformer-based networks trigger many invocations of CUDA
kernels operating in a producer-consumer fashion, adding a lot of cost for
transferring data to and from global memory and overhead from kernel launching.
Existing compiler-based approaches perform fine-grained fusion (e.g., fusion of
element-wise operations), leading to missed fusion opportunities. In contrast,
we fully exploit both fine-grained and coarse-grained fusion, tailored for
Transformer blocks.
**QKV and various fusions.** We merge the three Query (Q), Key (K), and Value (V)
weight matrices to dispatch a larger QKV GEMM to expose more parallelism and
improve data locality on GPU’s shared memory and register files, as shown in
Figure 3. Next, we combine the data-layout transformation of the QKV’s output
matrix with the bias addition. We then partition the large QKV matrix into
three transformed ones, used for the following self-attention computation.
As Figure 3 illustrates, we read the QKV matrix in consecutive rows (shown by
the red box) and write them into the three transformed Q, K, and V matrices. Since
each matrix starts from a different offset, we may have uncoalesced access to
the main memory. Thus, we use shared memory as an intermediate buffer to
rearrange the data so that it can be placed in consecutive parts of memory.
Even though we produce an uncoalesced pattern when accessing shared memory, we
reduce the cost of uncoalesced access to main memory and better exploit memory
bandwidth, resulting in a 3% to 5% performance improvement in end-to-end
training.
![QKV-Fusion](../assets/images/qkv_fusion.png)
Figure 3: QKV’s GEMM and transform Kernel-Fusion
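At the framework level, the merged QKV projection amounts to a single larger GEMM followed by a split and a 0213-style layout transform. The snippet below is a minimal PyTorch sketch of that idea only, not the fused CUDA kernel itself; the shapes and the `transform_0213` helper are illustrative assumptions.

```python
import torch
import torch.nn as nn

batch, seq_len, hidden, heads = 8, 128, 1024, 16
head_dim = hidden // heads

# One merged weight for Q, K and V instead of three separate GEMMs.
qkv_proj = nn.Linear(hidden, 3 * hidden)

x = torch.randn(batch, seq_len, hidden)
qkv = qkv_proj(x)                               # [B, S, 3*H] in a single GEMM
q, k, v = qkv.chunk(3, dim=-1)                  # three [B, S, H] views

def transform_0213(t):
    # [B, S, H] -> [B, heads, S, head_dim], the layout used by self-attention
    return t.view(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)

q, k, v = map(transform_0213, (q, k, v))
scores = torch.matmul(q, k.transpose(-1, -2)) / head_dim ** 0.5
probs = torch.softmax(scores, dim=-1)
context = torch.matmul(probs, v)                # [B, heads, S, head_dim]
```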
We perform additional fusions, such as merging the bias addition from the
attention-output GEMM with the addition from the residual connection and with
dropout. This allows accesses to happen in the register files and shared
memory, which are orders of magnitude faster than an expensive write-back to
global memory.
**Warp-level communication.** To alleviate the synchronization overhead among
parallel GPU cores and further increase the resource utilization of the fused
kernels, we use warp-level communication (data shuffle instructions) instead of the
default inter-warp communication. Taking the layer-normalization and SoftMax
kernels as examples, we perform each reduction operation inside a warp, while
distributing different reductions across different warps. This way, we
alleviate the synchronization among the parallel threads and further increase
the GPU resource utilization.
**Stochastic vs deterministic kernels.** DL training is generally robust to some
level of stochasticity, and in some cases, controlled noise such as dropout
acts as a regularizer that improves generalization. In designing our transformer
kernel, we embrace some level of stochasticity to improve throughput by
allowing for limited data race conditions to exist in the kernel: We leverage
implicit warp synchronous programming to achieve higher performance for the
warp-level cooperative operations [3]. The lack of explicit warp-level
synchronization acts as non-deterministic noise without affecting the overall
convergence behavior of the transformer kernels while giving a decent
throughput boost.
In addition, DeepSpeed also implements a non-stochastic transformer kernel with
explicit warp synchronization that produces deterministic results at the
expense of a small performance regression. Users can easily choose and switch
between the two versions depending on their usage scenarios: Stochastic version
pursues ultimate training performance goal, and deterministic version may save
development time by better facilitating experimentation and debugging.
In our experiments, we use stochastic kernels for pretraining BERT, while using
non-stochastic kernels for fine-tuning to achieve fully reproducible results. We
recommend using stochastic kernels for training tasks involving massive amounts
of data, such as pre-training, and the non-stochastic version when training with
limited data, such as fine-tuning, for more consistent results.
**Cost-effective rematerialization.** When fusing kernels of the different
operations, we observe that some operators are inexpensive to compute but
incur expensive data movement cost, such as addition of bias and dropout. For
these operations, we avoid saving their results in the forward pass, but
instead recompute them during the backward pass, which turns out to be much
faster than having their results written and reloaded from the main memory.
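A rough analogue of this rematerialization at the PyTorch level is activation checkpointing of such cheap operators, so that their outputs are recomputed in the backward pass instead of being stored. The sketch below, with illustrative shapes and a hypothetical `bias_dropout` helper, shows the idea; the actual DeepSpeed kernels apply it inside the fused CUDA code.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def bias_dropout(x, bias):
    # Cheap to recompute, but its output is as large as the activation itself.
    return F.dropout(x + bias, p=0.1, training=True)

x = torch.randn(8, 128, 1024, requires_grad=True)
bias = torch.randn(1024, requires_grad=True)

# The output of bias_dropout is not stored for backward; it is recomputed there.
# checkpoint() restores the RNG state, so the same dropout mask is reproduced.
y = checkpoint(bias_dropout, x, bias)
y.sum().backward()
```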
### (b) Invertible operators to save memory and run large batches
We also observe that the intermediate activations from several operators in the
Transformer blocks incur a large memory consumption, such as SoftMax and Layer
Norm. For these operators, we drop the inputs to these layers to reduce the
footprint of activation memory, by leveraging the fact that they are invertible
functions, which are functions whose backward pass is independent of the inputs
and can be formulated based only on the outputs [2]. Figure 4 and Figure 5 show
examples of the original SoftMax and Layer-Norm implementations in
PyTorch versus the invertible implementations in DeepSpeed. Through this
optimization, we are able to reduce the activation memory of the operator by
half, and the reduced memory allows us to train with larger batch sizes, which
once again improves GPU efficiency.
![Softmax-torch](../assets/images/softmax_pytorch.gif)
![Softmax-DS](../assets/images/softmax_deepspeed.gif)
Figure 4: DeepSpeed invertible SoftMax operation versus Default PyTorch SoftMax operation
![LayerNorm-torch](../assets/images/layernorm_pytorch.gif)
![LayerNorm-DS](../assets/images/layernorm_deepspeed.gif)
Figure 5: DeepSpeed invertible LayerNorm operation versus Default PyTorch LayerNorm operation
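For SoftMax, the invertibility can be written down directly: since dL/dx = y * (dL/dy - sum(y * dL/dy)) along the softmax dimension, the backward pass needs only the output y and never the input x. Below is a minimal PyTorch sketch of such an input-dropping SoftMax; it illustrates the math, not DeepSpeed's fused implementation, and the tensor shapes are illustrative.

```python
import torch

class InvertibleSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim=-1):
        y = torch.softmax(x, dim=dim)
        ctx.save_for_backward(y)   # keep only the output; the input x is dropped
        ctx.dim = dim
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (y,) = ctx.saved_tensors
        dot = (y * grad_y).sum(dim=ctx.dim, keepdim=True)
        return y * (grad_y - dot), None   # no gradient for the dim argument

x = torch.randn(2, 12, 128, 128, requires_grad=True)
y = InvertibleSoftmax.apply(x)
(y * torch.randn_like(y)).sum().backward()
```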
## Overlapping I/O with Computation through Asynchronous Prefetching Queue
Beyond highly optimized transformer kernels, the BERT training has other
performance limiting factors, e.g., data loading. We develop our own
asynchronous worker which prefetches batches of data into a queue only at “safe
points” -- points when the CPUs are idle (e.g., right after asynchronously
launching the forward pass). In this way, we make sure that there is no
dequeuing and copying data from CPU to GPU when there is computation on the CPU
side. This is different from the default PyTorch data loader, which can
prefetch data at any point and cause performance interference. By using this
method, we hide almost all I/O overhead, which accounts for 4% of the original
training time.
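A simplified sketch of the prefetching pattern is shown below; it is illustrative only (the actual DeepSpeed data worker is more involved), and the `train_with_prefetch` function, loader, and model are assumed placeholders. The next batch is copied to the GPU right after the forward pass has been launched asynchronously, i.e., at a safe point where the CPU would otherwise be idle.

```python
import torch

def train_with_prefetch(loader, model, optimizer, device):
    it = iter(loader)
    # Batches are assumed to come from a pinned-memory DataLoader so the
    # host-to-device copy can overlap with GPU work.
    next_batch = next(it).to(device, non_blocking=True)
    for step in range(len(loader)):
        batch = next_batch
        loss = model(batch).sum()        # forward launched asynchronously on the GPU
        # "Safe point": the CPU is idle while the GPU runs the forward pass,
        # so fetch and copy the next batch to the GPU here.
        if step + 1 < len(loader):
            next_batch = next(it).to(device, non_blocking=True)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```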
## Exploiting Sparsity of BERT’s Output Processing
We improve the end-to-end training time by 5.4% by recognizing and exploiting
sparsity in BERT’s output processing. The output processing involves two steps:
i) BERT projection from the hidden output dimension of the final transformer
layer to the language vocabulary, using a matrix-matrix multiplication, and ii)
a cross-entropy of the masked output tokens to get each sequence’s
prediction error. The cost of the first step is proportional to the vocabulary
size, hidden output dimension and the sequence length, and can be as expensive
as a transformer layer computation or more. However, only about 15% of the
tokens are masked, and we only need the cross-entropy for the masked tokens.
Therefore, the projection can be done as an efficient sparse computation. To do
so, we discard the rows of the final transformer layer output that correspond to
the non-masked tokens before doing the projection, reducing the computation
cost of output processing by 85%.
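A PyTorch-level sketch of this sparse output processing is given below; the shapes and the convention that non-masked positions carry a label of -1 are illustrative assumptions, not the exact BERT data format.

```python
import torch
import torch.nn.functional as F

batch, seq_len, hidden, vocab = 8, 512, 1024, 30522
hidden_states = torch.randn(batch, seq_len, hidden)              # final layer output
labels = torch.full((batch, seq_len), -1, dtype=torch.long)      # -1 marks non-masked tokens
labels[:, :77] = torch.randint(0, vocab, (batch, 77))            # ~15% of tokens are masked

proj = torch.nn.Linear(hidden, vocab)

# Keep only the rows that correspond to masked tokens before the projection.
flat_states = hidden_states.view(-1, hidden)
flat_labels = labels.view(-1)
masked_idx = (flat_labels >= 0).nonzero(as_tuple=True)[0]

logits = proj(flat_states[masked_idx])                           # ~15% of the original GEMM
loss = F.cross_entropy(logits, flat_labels[masked_idx])
```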
## Pre-LayerNorm vs Post-LayerNorm Architecture
We observe that with large batch size (e.g., 64K) the default BERT pre-training
suffers from training instability, which can result in model divergence or
convergence to bad/suspicious local optima. Further investigation shows that
the default BERT has a vanishing gradient issue. To mitigate it, we changed the
placement of LayerNorm from the output of each sublayer (Post-LayerNorm) to the
input stream of the sublayers in the Transformer block (Pre-LayerNorm), a
modification described by several recent works on neural machine translation.
Pre-LayerNorm results in several useful characteristics, such as avoiding
vanishing gradients, stable optimization, and a performance gain. It allows us
to train at an aggregate batch size of 64K with an increased learning rate and
faster convergence.
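In code, the change amounts to moving the LayerNorm from after the residual addition to the sublayer's input. A schematic PyTorch sketch follows (illustrative module names, not the fused DeepSpeed kernel):

```python
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Default BERT placement: normalize after the residual addition."""
    def __init__(self, hidden, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(hidden)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreLNBlock(nn.Module):
    """Pre-LayerNorm placement: normalize the sublayer input, keep a clean residual path."""
    def __init__(self, hidden, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(hidden)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))
```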
To try out these optimizations and training recipe, please check out our [BERT
training tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/) and
source code at the [DeepSpeed GitHub
repo](https://github.com/microsoft/deepspeed).
### References
[1] "NVIDIA Clocks World’s Fastest BERT Training Time and Largest Transformer Based Model, Paving Path For Advanced Conversational AI" [https://devblogs.nvidia.com/training-bert-with-gpus/](https://devblogs.nvidia.com/training-bert-with-gpus/).
[2] S. R. Bulo, L. Porzi, and P. Kontschieder, "In-place activated batch norm for memory-optimized training of dnns" 2017. [http://arxiv.org/abs/1712.02616](http://arxiv.org/abs/1712.02616).
[3] Mark Harris and Kyrylo Perelygin, "Cooperative Groups: Flexible CUDA Thread Programming", [https://devblogs.nvidia.com/cooperative-groups/]( https://devblogs.nvidia.com/cooperative-groups/).
---
title: "BingBertSQuAD Fine-tuning"
excerpt: ""
---
In this tutorial we will be adding DeepSpeed to the BingBert model for the SQuAD fine-tuning task, called "BingBertSquad" henceforth. We will also demonstrate performance gains.
## Overview
Please clone the repository and go to the `examples/BingBertSquad` folder to follow along.
### Pre-requisites
* Download SQuAD data:
* Training set: [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* Validation set: [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
You also need a pre-trained BERT model checkpoint from DeepSpeed. We will use checkpoint 162 from the
BERT pre-training [tutorial](bert-pretraining).
* Pre-training checkpoint: `training_state_checkpoint_162.tar`
Note that the BERT model in the file `train-v1.1.json_bert-large-uncased_384_128_64` is not strictly required as it will be downloaded automatically if it is not present locally on the cluster.
### Running BingBertSquad
- **Unmodified (BaseLine):** If you would like to run unmodified BingBertSquad with the pre-processed data, there is a helper script which you can invoke via: `bash run_squad_baseline.sh 8 <PATH_TO_CHECKPOINT>/training_state_checkpoint_162.tar` where the first argument `8` is the number of GPUs and the second argument is the path to the pre-training checkpoint. This bash script sets the parameters and invokes `nvidia_run_squad_baseline.py`.
- **Modified (DeepSpeed):** This is similar to baseline; just substitute `run_squad_baseline.sh` with `run_squad_deepspeed.sh` which invokes `nvidia_run_squad_deepspeed.py`.
## DeepSpeed Integration
The main DeepSpeed modified script is `nvidia_run_squad_deepspeed.py`; the line `import deepspeed` enables you to use DeepSpeed.
Make sure that the number of GPUs specified in the job is available; otherwise this will yield an out-of-memory error. The wrapper script `run_BingBertSquad.sh` and the test script `run_tests.sh` essentially serve to automate training; they may also be used as guidelines for setting parameters and launching the fine-tuning task.
### Configuration
The `deepspeed_bsz24_config.json` file allows the user to specify DeepSpeed options in terms of batch size, learning rate, precision, and other parameters. When running `nvidia_run_squad_deepspeed.py`, in addition to the `-d` flag to enable DeepSpeed, the appropriate DeepSpeed configuration file must be specified using `--deepspeed_config <deepspeed_bsz24_config.json>`.
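For orientation, a configuration of this general shape can be generated as below. The values and the output filename are illustrative; this is not the exact content of the `deepspeed_bsz24_config.json` shipped with BingBertSquad.

```python
import json

# Illustrative values only; consult the deepspeed_bsz24_config.json shipped with
# BingBertSquad for the actual settings used in this tutorial.
ds_config = {
    "train_batch_size": 24,
    "train_micro_batch_size_per_gpu": 3,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 3e-5
        }
    },
    "fp16": {
        "enabled": True
    }
}

with open("my_deepspeed_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```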
### Argument Parsing
The first step in applying DeepSpeed is adding arguments to BingBertSquad, using `deepspeed.add_config_arguments()` at the beginning of the main entry point, as in the `main()` function in `nvidia_run_squad_deepspeed.py`. The argument passed to `add_config_arguments()` is obtained from the `get_argument_parser()` function in `utils.py`.
```python
parser = get_argument_parser()
# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
```
Similarly, all the options and their corresponding descriptions are available in `utils.py`.
### Training
#### Initialization
DeepSpeed has an initialization function that creates the model, optimizer, and LR scheduler. For BingBertSquad, we simply augment the Baseline script with the initialize function as follows.
```python
model, optimizer, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=optimizer_grouped_parameters,
dist_init_required=False)
```
Another convenient feature of DeepSpeed is its `step()` function, which can be called directly as `model.step()`. It hides the `fp16_optimizer` from the user, whereas `optimizer.step()` in the baseline code (as in other models in this tutorial) requires explicit handling of the FP16 case.
#### Forward pass
This is identical in both Baseline and DeepSpeed, and is performed by `loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)`.
#### Backward pass
In the Baseline script you need to handle the all-reduce operation at the gradient accumulation boundary explicitly, using `enable_need_reduction()` followed by `optimizer.backward(loss)` in FP16 or `loss.backward()` in FP32. With DeepSpeed, you may simply call `model.backward(loss)`.
#### Weight updates
In the Baseline script, you are required to explicitly specify the optimizer as `FusedAdam` (along with handling of dynamic loss scaling) in FP16 or `BertAdam` in FP32, followed by calls to `optimizer.step()` and `optimizer.zero_grad()`. DeepSpeed handles this internally (the optimizer is set via the JSON config) when `initialize()` is called, so you do not need to write this code explicitly; simply call `model.step()`.
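Putting these pieces together, the DeepSpeed training step for BingBertSquad reduces to the following pattern; this is a schematic sketch that assumes `args`, `model`, `optimizer_grouped_parameters`, and a `train_dataloader` have been set up earlier in the script.

```python
import deepspeed

model, optimizer, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=optimizer_grouped_parameters,
                                              dist_init_required=False)

for batch in train_dataloader:
    input_ids, segment_ids, input_mask, start_positions, end_positions = batch
    loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
    model.backward(loss)   # replaces loss.backward() / optimizer.backward(loss)
    model.step()           # replaces optimizer.step() + optimizer.zero_grad()
```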
Congratulations! Porting into DeepSpeed is complete.
### Evaluation
Once training is complete, the EM and F1 scores may be obtained from the following command:
`python evaluate-v1.1.py <PATH_TO_DEVSET>/dev-v1.1.json <PATH_TO_PREDICTIONS>/predictions.json`
### DeepSpeed Improvements
The table summarizing the results is given below. In all cases, the batch size is set to 24 and training is conducted on 8 GPUs for 2 epochs on the DLTS RR1 DGX-2 hypercluster. A set of parameters (seeds and learning rates) was tried and the best ones were selected. All learning rates were 3e-5; Baseline seeds were 42 and DeepSpeed seeds were 10.
| Case | Precision | EM | F1 | Throughput (iterations/sec) |
| --------- | --------- | ----- | ----- | ---------- |
| DeepSpeed | FP16 | 84.38 | 91.11 | 9.6 |
| Baseline | FP16 | 84.39 | 91.29 | 8.4 |
| DeepSpeed | FP32 | 84.20 | 91.06 | 3.7 |
| Baseline | FP32 | 84.20 | 90.91 | 2.7 |
In terms of throughput (expressed in iterations processed per second), DeepSpeed outperforms the baseline while achieving the desired accuracy (in terms of EM and F1 scores).
## Fine-tuning the model pre-trained with DeepSpeed Transformer Kernels
For pre-training your model, please see the [BERT Pre-Training](/bert-pretraining/) tutorial for detailed instructions.
If you have already obtained a checkpoint of your model, use the following configuration to fine-tune your pre-trained checkpoint.
| Parameters | Value |
| ------------------------ | ------------------------- |
| Total batch size | 24 |
| Train micro batch size per gpu | 3 |
| Optimizer | Adam |
| Learning rate | 4e-5 |
| Sequence-length | 384 |
| Weight-decay | 0.0 |
| Epoch count | 2 |
### Enabling DeepSpeed's Transformer Kernel
DeepSpeed's optimized transformer kernel must be enabled during fine-tuning
if and only if it was used also during pre-training, because the transformer
kernel has its own parameters saved in checkpoint files.
To enable the transformer kernel for higher performance, first add an argument
`--deepspeed_transformer_kernel` in `utils.py`; we can set it to `False` by
default so that it can easily be turned on and off.
```python
parser.add_argument('--deepspeed_transformer_kernel',
default=False,
action='store_true',
help='Use DeepSpeed transformer kernel to accelerate.')
```
Then in the `BertEncoder` class of the modeling source file, instantiate
transformer layers using the DeepSpeed transformer kernel as below.
```python
if args.deepspeed_transformer_kernel:
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig, DeepSpeedConfig
if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
ds_config = DeepSpeedConfig(args.deepspeed_config)
else:
raise RuntimeError('deepspeed_config is not found in args.')
cuda_config = DeepSpeedTransformerConfig(
batch_size = ds_config.train_micro_batch_size_per_gpu,
max_seq_length = args.max_seq_length,
hidden_size = config.hidden_size,
heads = config.num_attention_heads,
attn_dropout_ratio = config.attention_probs_dropout_prob,
hidden_dropout_ratio = config.hidden_dropout_prob,
num_hidden_layers = config.num_hidden_layers,
initializer_range = config.initializer_range,
seed = args.seed,
fp16 = ds_config.fp16_enabled,
pre_layer_norm=True)
self.layer = nn.ModuleList([copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config)) for i in range(config.num_hidden_layers)])
else:
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
```
All configuration settings come from the DeepSpeed configuration file and
command-line arguments, and thus we must pass the `args` variable here in this model.
Note:
1. `batch_size` is the maximum batch size of the input data; no fine-tuning training data or prediction data should exceed this threshold, otherwise an exception will be thrown. In the DeepSpeed configuration file the micro batch size is defined as `train_micro_batch_size_per_gpu`; e.g., if it is set to 8 and prediction uses a batch size of 12, we can either use 12 as the transformer kernel batch size or use the `--predict_batch_size` argument to set the prediction batch size to 8 or a smaller number.
2. `local_rank` in DeepSpeedTransformerConfig is used to assign the transformer kernel to the correct device. Since the model already calls `set_device()` before reaching this point, it does not need to be set here.
For more details about the transformer kernel, please see [DeepSpeed Transformer Kernel](/transformer_kernel/) and [DeepSpeed Fast-Bert Training](/fast_bert/).
### Dropout Setting
For fine-tuning, we only use the deterministic transformer to have reproducible fine-tuning results. However, we choose different dropout values based on whether pre-training was done using the deterministic or the stochastic transformer (please see the [Transformer tutorial](/transformer_kernel/) for more detail on selecting these two modes).
For a model pre-trained with the deterministic transformer, we use the same dropout ratio used in pre-training (0.1). However, we slightly increase the dropout ratio when fine-tuning a model pre-trained using the stochastic transformer, to compensate for the lack of stochastic noise during fine-tuning.
| Pretraining mode | Dropout ratio |
| ------------------------ | ------------------------- |
| Deterministic | 0.1 |
| Stochastic | 0.12 - 0.14 |
### Results
Fine-tuning the model pre-trained using the DeepSpeed Transformer and the recipe in [DeepSpeed Fast-Bert Training](/fast_bert/) should yield an F1 score of 90.5, which is expected to increase if you run the pre-training longer than suggested in the tutorial.
......@@ -3,12 +3,6 @@ title: "BERT Pre-training"
excerpt: ""
---
**Note:**
This tutorial will be updated to include new details for reproducing the
recent 44-minute [BERT pre-training record](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/).
Please check again soon!
{: .notice--warning}
In this tutorial we will apply DeepSpeed to pre-train the BERT
(**B**idirectional **E**ncoder **R**epresentations from **T**ransformers),
which is widely used for many Natural Language Processing (NLP) tasks. The
......@@ -29,6 +23,7 @@ We work from adaptations of
We have forked this repo under
[DeepSpeedExamples/bing_bert](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert)
and made several modifications in their script:
* We adopted the modeling code from NVIDIA's BERT under `bing_bert/nvidia/`.
* We extended the data pipeline from [Project Turing](https://msturing.org/)
under `bing_bert/turing/`.
......@@ -246,13 +241,76 @@ modifications. We have included a modified `train.py` file called
applied.
### Enabling DeepSpeed's Transformer Kernel
To enable the transformer kernel for higher performance, first add an argument
`--deepspeed_transformer_kernel` in `utils.py`; we can set it to `False` by
default so that it can easily be turned on and off.
```python
parser.add_argument('--deepspeed_transformer_kernel',
default=False,
action='store_true',
help='Use DeepSpeed transformer kernel to accelerate.')
```
Then in the `BertEncoder` class of the modeling source file, instantiate
transformer layers using the DeepSpeed transformer kernel as below.
```python
if args.deepspeed_transformer_kernel:
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig, DeepSpeedConfig
if hasattr(args, 'deepspeed_config') and args.deepspeed_config:
ds_config = DeepSpeedConfig(args.deepspeed_config)
else:
raise RuntimeError('deepspeed_config is not found in args.')
cuda_config = DeepSpeedTransformerConfig(
batch_size = ds_config.train_micro_batch_size_per_gpu,
max_seq_length = args.max_seq_length,
hidden_size = config.hidden_size,
heads = config.num_attention_heads,
attn_dropout_ratio = config.attention_probs_dropout_prob,
hidden_dropout_ratio = config.hidden_dropout_prob,
num_hidden_layers = config.num_hidden_layers,
initializer_range = config.initializer_range,
local_rank = args.local_rank if hasattr(args, 'local_rank') else -1,
seed = args.seed,
fp16 = ds_config.fp16_enabled,
pre_layer_norm=True,
attn_dropout_checkpoint=args.attention_dropout_checkpoint,
normalize_invertible=args.normalize_invertible,
gelu_checkpoint=args.gelu_checkpoint,
stochastic_mode=True)
self.layer = nn.ModuleList([copy.deepcopy(DeepSpeedTransformerLayer(i, cuda_config)) for i in range(config.num_hidden_layers)])
else:
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
```
All configuration settings come from the DeepSpeed configuration file and
command-line arguments, and thus we must pass the `args` variable here in this model.
Note:
1. `batch_size` is the maximum batch size of the input data; no fine-tuning training data or prediction data should exceed this threshold, otherwise an exception will be thrown. In the DeepSpeed configuration file the micro batch size is defined as `train_micro_batch_size_per_gpu`; e.g., if it is set to 8 and prediction uses a batch size of 12, we can either use 12 as the transformer kernel batch size or use the `--predict_batch_size` argument to set the prediction batch size to 8 or a smaller number.
2. `local_rank` in DeepSpeedTransformerConfig is used to assign the transformer kernel to the correct device. Since the model already calls `set_device()` before reaching this point, it does not need to be set here.
3. `stochastic_mode` gives higher performance when it is enabled; we enable it in pre-training and disable it in fine-tuning.
4. The transformer kernel has its own parameters, so checkpoint files
generated with the transformer kernel must be loaded by a model with the
transformer kernel enabled (such as in fine-tuning).
For more details about the transformer kernel, please see [DeepSpeed Transformer Kernel](/transformer_kernel/) and [DeepSpeed Fast-Bert Training](/fast_bert/).
### Start Training
An example of launching `deepspeed_train.py` on four nodes with four GPUs each would be:
```bash
deepspeed --num_nodes 4 \
deepspeed_train.py \
--deepspeed \
--deepspeed_config deepspeed_bsz4096_adam_config.json
--deepspeed_config deepspeed_bsz4096_adam_config.json \
--cf /path-to-deepspeed/examples/tests/bing_bert/bert_large_adam_seq128.json \
--train_batch_size 4096 \
--max_seq_length 128 \
......@@ -263,6 +321,7 @@ deepspeed --num_nodes 4 \
--delay_allreduce \
--max_steps 32 \
--print_steps 1 \
--deepspeed_transformer_kernel \
--output_dir <output_directory>
```
See the [Getting Started](/getting-started/) guide for more information on
......@@ -270,57 +329,52 @@ launching DeepSpeed.
------
## Reproducing BERT Training Results with DeepSpeed
## Reproducing Fastest BERT Training Results with DeepSpeed
Our BERT training result is competitive across the industry in terms of
achieving F1 score of 90.5 or better on the SQUAD 1.1 dev set:
We achieve the fastest BERT training time while remaining competitive across the industry in terms of achieving an F1 score of 90.5 or better on the SQuAD 1.1 dev set. Please follow the [Fine-tuning](/bert-finetuning/) tutorial to fine-tune a model pre-trained with the transformer kernel and reproduce the SQuAD F1 score.
- We completed BERT pre-training in 44 minutes using 1024 V100 GPUs (64 NVIDIA DGX-2 nodes). In comparison, the previous SOTA from NVIDIA took 47 minutes using 1472 V100 GPUs. DeepSpeed is not only faster but also uses 30% fewer resources. Using the same 1024 GPUs, NVIDIA BERT is 52% slower than DeepSpeed, taking 67 minutes to train.
- Compared with the original BERT training time from Google, it took them about 96 hours to reach parity on 64 TPU2 chips, while it took us less than 9 hours on 4 DGX-2 nodes with 64 V100 GPUs.
- On 256 GPUs, it took us 2.4 hours, faster than the state-of-the-art result (3.9 hours) from NVIDIA using their superpod on the same number of GPUs
([link](https://devblogs.nvidia.com/training-bert-with-gpus/)).
![DeepSpeed BERT Training Time](/assets/images/end-to-end-bert-training.png){: .align-center}
Our configuration for the BERT training result above can be reproduced with
the scripts/json configs in our DeepSpeedExamples repo. Below is a table containing a
summary of the configurations. Specifically see the
`ds_train_bert_bsz64k_seq128.sh` and `ds_train_bert_bsz32k_seq512.sh` scripts
for more details in
[DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert).
| Parameters | 128 Sequence | 512 Sequence |
| ------------------------ | ------------------------- | ------------------------- |
| Total batch size | 64K | 32K |
| Train micro batch size per gpu | 64 | 8 |
| Optimizer | Lamb | Lamb |
| Learning rate | 11e-3 | 2e-3 |
| Initial learning rate (`lr_offset`) | 10e-4 | 0.0 |
| Min Lamb coefficient | 0.01 | 0.01 |
| Max Lamb coefficient | 0.3 | 0.3 |
| Learning rate scheduler | `warmup_exp_decay_exp` | `warmup_exp_decay_exp` |
| Warmup proportion | 0.02 | 0.02 |
| Decay rate | 0.90 | 0.90 |
| Decay step | 250 | 150 |
| Max training steps | 7500 | 7500 |
| Rewarm learning rate | N/A | True |
| Output checkpoint number | 150 | 160-162 |
| Sample count | 403M | 18-22M |
| Epoch count | 150 | 160-162 |
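As a quick sanity check on the batch sizes above, DeepSpeed relates the global batch size to the per-GPU micro batch size as `train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * number_of_gpus`. The sketch below is illustrative only; the GPU counts and accumulation steps are assumptions, not values taken from the table.
```python
def effective_batch_size(micro_batch_per_gpu, grad_accum_steps, num_gpus):
    """Global batch size seen by the optimizer at each step."""
    return micro_batch_per_gpu * grad_accum_steps * num_gpus


# Sequence length 128: a micro batch of 64 per GPU on 1024 GPUs with no gradient
# accumulation reaches the 64K total batch size.
assert effective_batch_size(64, 1, 1024) == 64 * 1024

# Sequence length 512: a micro batch of 8 per GPU on 1024 GPUs needs 4 accumulation
# steps to reach the 32K total batch size.
assert effective_batch_size(8, 4, 1024) == 32 * 1024
```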
## DeepSpeed Single GPU Throughput Results
![DeepSpeed Single GPU Bert Training Throughput](/assets/images/single-gpu-throughput.png){: .align-center}
Compared to SOTA, DeepSpeed significantly improves single-GPU performance for transformer-based models like BERT. The figure above shows the single-GPU throughput of training BERT-Large optimized through DeepSpeed, compared with two well-known PyTorch implementations, NVIDIA BERT and HuggingFace BERT. DeepSpeed reaches throughputs as high as 64 and 53 teraflops (corresponding to 272 and 52 samples/second) for sequence lengths of 128 and 512, respectively, exhibiting up to 28% throughput improvement over NVIDIA BERT and up to 62% over HuggingFace BERT. We also support up to 1.8x larger batch sizes without running out of memory.
For more details on how we achieve the record-breaking BERT training time, please check out our deep dive into DeepSpeed BERT: [Fastest BERT Training](https://www.deepspeed.ai/news/2020/05/18/bert-record.html).
---
title: "DeepSpeed Transformer Kernel"
---
This tutorial shows how to enable the DeepSpeed transformer kernel and set its different configuration parameters.
## DeepSpeed Transformer Kernel
Transformer layers are ubiquitous in many recent sequence-processing models, such as those used in Natural Language Processing. Training transformer-based networks therefore needs to be highly efficient, so that scientists can explore different models across various application domains in a reasonable amount of time. To this end, we have developed a new kernel for transformer networks that includes several optimizations specific to these layers; it boosts training throughput on a single GPU and scales well as the number of GPUs increases. For more information on the details of the transformer kernel, please visit our recent blog post on the [fastest BERT training](/fast_bert/).
## Prerequisites
To use the transformer kernel for training a model, you should integrate DeepSpeed into your training script using the [Getting Started](/getting-started/) guide.
### **Integrate Transformer Kernel**
First of all, you need to integrate the transformer kernel into the top-level model. Here, we show an example of instantiating the transformer kernel using the Pre-LN BERT-Large configuration settings. This configuration has 24 layers with a hidden dimension of 1024, and uses a sequence length of 128 and a batch size of 64. To add all these layers, we copy the same layer specification `num_hidden_layers` times with different IDs inside a `ModuleList`.
```python
config = DeepSpeedTransformerConfig(batch_size = 64,
max_seq_length = 128,
hidden_size = 1024,
heads = 16,
attn_dropout_ratio = 0.1,
hidden_dropout_ratio = 0.1,
num_hidden_layers = 24,
initializer_range = 0.02,
local_rank = 0,
seed = 1234,
fp16 = True,
pre_layer_norm=True,
attn_dropout_checkpoint=False,
normalize_invertible=False,
gelu_checkpoint=False)
self.layer = nn.ModuleList([
    copy.deepcopy(DeepSpeedTransformerLayer(i, config))
    for i in range(config.num_hidden_layers)
])
```
### **Transformer Kernel Parameters**
The transformer kernel is configured by a number of parameters which allow users to
explore different settings. We partition these parameters into four categories:
1. General configuration, used by different types of transformer layers
2. Environment parameters, specifying the system's settings
3. A high-performance flag, optimizing training with stochastic computation
4. Memory-optimization flags, trading computation for reduced memory
The general parameters for configuring the transformer kernel are:
1. `batch_size`: The micro-batch size used for running the kernel on each GPU
2. `max_seq_length`: The sequence-length of the model being trained with DeepSpeed
3. `hidden_size`: The hidden size of the transformer layer
4. `heads`: The number of heads in the self-attention of the transformer layer
5. `attn_dropout_ratio`: The ratio of dropout for the attention's output
6. `hidden_dropout_ratio`: The ratio of dropout for the transformer's output
7. `num_hidden_layers`: The number of transformer layers
8. `pre_layer_norm`: Select between Pre-LN or Post-LN transformer architecture
The environment parameters of the transformer kernel include:
1. `local_rank`: The rank of the current GPU running the transformer kernel
2. `seed`: The random seed for the dropout layer
3. `fp16`: Enable half-precision computation
4. `initializer_range`: Bert's initializer range
High-performance optimization flag:
1. `stochastic_mode`: Turning on this flag makes training run about 2% faster on average. Note that this flag introduces some non-determinism and can produce different results on different runs. However, we have seen that enabling it does not affect pre-training tasks such as BERT, which can still reach a high accuracy level. For downstream tasks such as fine-tuning, we recommend turning it off in order to be able to reproduce the same results through the regular kernel execution (see the sketch below).
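As a minimal sketch of this recommendation (assuming `DeepSpeedTransformerConfig` can be imported from the top-level `deepspeed` package, and reusing the BERT-Large settings from the example above), one could toggle the flag based on the training phase:
```python
from deepspeed import DeepSpeedTransformerConfig

# Shared settings, mirroring the Pre-LN BERT-Large example above.
common = dict(batch_size=64, max_seq_length=128, hidden_size=1024, heads=16,
              attn_dropout_ratio=0.1, hidden_dropout_ratio=0.1, num_hidden_layers=24,
              initializer_range=0.02, local_rank=0, seed=1234, fp16=True,
              pre_layer_norm=True)

# Pre-training: accept a little non-determinism in exchange for ~2% higher throughput.
pretrain_config = DeepSpeedTransformerConfig(stochastic_mode=True, **common)

# Fine-tuning: keep the regular, deterministic kernel execution for reproducibility.
finetune_config = DeepSpeedTransformerConfig(stochastic_mode=False, **common)
```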
The memory-optimization flags consist of:
1. `attn_dropout_checkpoint`: Enable checkpointing of attention dropout to save memory
2. `normalize_invertible`: Enable invertible LayerNorm execution (dropping the input activation)
3. `gelu_checkpoint`: Enable checkpointing of Gelu activation output to save memory
To illustrate the model-configuration changes required to use the transformer kernel in model training, we use a BERT model and go through the different configurations needed to support different sequence lengths and batch sizes. Please see the instructions in the [BERT training tutorial](/bert-pretraining/).
### **Memory Optimization Flags**
The transformer kernel includes several techniques that save memory in different parts of a layer. We expose them as configurable settings that can be enabled when calling the kernel. Turning on each of these optimization flags allows larger batch sizes; even though we trade some performance for memory with these techniques, end-to-end training efficiency increases thanks to the larger batch size.
By setting the `normalize_invertible` flag, we force the kernel to drop the input activations of the transformer's normalization layers. We can do this because the kernel includes an optimization that computes the gradients of this layer's parameters and input using only the output activations.
The `attn_dropout_checkpoint` and `gelu_checkpoint` flags use a checkpointing approach in which we drop the inputs to some parts of the transformer layer, the attention dropout and the GeLU, in order to save a substantial portion of the activation memory. Based on our performance profiling, the cost of rematerializing these two is negligible, and the performance benefit of running a larger batch size more than compensates for it.
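For instance, a hedged sketch of a configuration for sequence length 512 with a per-GPU micro batch of 16 (mirroring the command-line example at the end of this tutorial) would enable only attention-dropout checkpointing; the other values are reused from the earlier BERT-Large example and are illustrative, not prescriptive.
```python
from deepspeed import DeepSpeedTransformerConfig

seq512_config = DeepSpeedTransformerConfig(batch_size=16,
                                           max_seq_length=512,
                                           hidden_size=1024,
                                           heads=16,
                                           attn_dropout_ratio=0.1,
                                           hidden_dropout_ratio=0.1,
                                           num_hidden_layers=24,
                                           initializer_range=0.02,
                                           local_rank=0,
                                           seed=1234,
                                           fp16=True,
                                           pre_layer_norm=True,
                                           # Drop attention-dropout inputs to save activation memory.
                                           attn_dropout_checkpoint=True,
                                           # Larger batches would additionally need these two flags.
                                           normalize_invertible=False,
                                           gelu_checkpoint=False)
```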
The following table shows which memory-optimization flags need to be turned on when running BERT-Large on an NVIDIA V100 GPU with 32GB of memory, for different micro-batch sizes and sequence lengths. For the two sequence lengths used in our experiments, 128 and 512, we have seen that a larger batch size improves the overall training performance for both. Please see our [Fast-BERT](/fast_bert/) blog post for more information on the performance evaluation of these configurations; a small helper that encodes this table follows it.
| Micro-batch size | 128 sequence-length | 512 sequence-length |
| :--------------: | :-----------------------: | :--------------------------------------: |
| > 12 | - | `attn_dropout_checkpoint` |
| > 16 | - | `normalize_invertible`, `gelu_checkpoint`|
| > 80 | `normalize_invertible` | OOM |
| > 112 | `attn_dropout_checkpoint` | OOM |
| > 128 | `gelu_checkpoint` | OOM |
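To make the table easier to apply in code, here is a hypothetical helper; the function name, the return format, and the assumption that the flags accumulate as the micro-batch size grows are ours, not part of DeepSpeed.
```python
def suggested_memory_flags(seq_length, micro_batch_size):
    """Memory-optimization flags suggested by the table above (BERT-Large on a 32GB V100)."""
    flags = {"attn_dropout_checkpoint": False,
             "normalize_invertible": False,
             "gelu_checkpoint": False}
    if seq_length == 128:
        if micro_batch_size > 128:
            flags.update(normalize_invertible=True, attn_dropout_checkpoint=True, gelu_checkpoint=True)
        elif micro_batch_size > 112:
            flags.update(normalize_invertible=True, attn_dropout_checkpoint=True)
        elif micro_batch_size > 80:
            flags.update(normalize_invertible=True)
    elif seq_length == 512:
        if micro_batch_size > 16:
            flags.update(attn_dropout_checkpoint=True, normalize_invertible=True, gelu_checkpoint=True)
        elif micro_batch_size > 12:
            flags.update(attn_dropout_checkpoint=True)
    return flags


# Sequence length 512 with a micro batch of 16 only needs attention-dropout
# checkpointing, matching the command-line example at the end of this tutorial.
print(suggested_memory_flags(512, 16))
```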
### **Enable Transformer Kernel**
As mentioned earlier, to run the transformer network with the custom DeepSpeed kernel, we only need to pass the `deepspeed_transformer_kernel` option when running the training script. Below, we show an example of how we pass this parameter to the `deepspeed` launcher, alongside the rest of the parameters for the BERT pre-training task.
```bash
deepspeed deepspeed_train.py \
--cf bert_large_lamb.json \
--max_seq_length 512 \
--print_steps 100 \
--deepspeed \
--deepspeed_transformer_kernel \
--deepspeed_config deepspeed_bsz32K_lamb_config_seq512.json \
--rewarmup \
--lr_schedule "EE" \
--lr_offset 0.0 \
--attention_dropout_checkpoint \
--load_training_checkpoint ${CHECKPOINT_BASE_PATH} \
--load_checkpoint_id ${CHECKPOINT_EPOCH150_NAME}
```
In addition to the transformer kernel flag, we can specify the memory-optimization settings discussed earlier. As an example, we use the `attention_dropout_checkpoint` option here when running with sequence length 512, in order to run a micro-batch size of 16 on each GPU. If a larger batch size is required, we can turn on the remaining memory-optimization flags as well.