Unverified Commit c78c29f9 authored by Reza Yazdani, committed by GitHub

supporting different hidden dimensions (#559)



* supporting different hidden dimensions

* add support for larger hidden dimensions (greater than 8K)

* remove empty line

* add loop unrolling factor for dropout kernels

* update different kernels based on the reviews
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 17f36f1b
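Every kernel touched below follows the same recipe: replace launch shapes that were hard-wired to the hidden dimension with a grid-stride loop over vectorized (float4/__half2) groups, process unroll_factor = 4 elements per iteration, and finish with a scalar tail for whatever the unrolled loop cannot cover. A minimal, illustrative sketch of that recipe (not code from this commit; the kernel name and the single-thread tail are assumptions made purely for illustration):

// Illustrative sketch only: a float4-vectorized grid-stride loop with unroll factor 4,
// plus a scalar tail for the last (N % 4) elements, handled here by one thread for simplicity.
__global__ void scale_kernel_sketch(const int N, const float scale, float* data)
{
    const int unroll = 4;
    float4* data4 = reinterpret_cast<float4*>(data);
    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < N / unroll;
         j += blockDim.x * gridDim.x) {
        float4 v = data4[j];
        v.x *= scale;
        v.y *= scale;
        v.z *= scale;
        v.w *= scale;
        data4[j] = v;
    }
    // Scalar tail: the last N % 4 elements that the vectorized loop skips.
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        for (int i = (N / unroll) * unroll; i < N; i++) data[i] *= scale;
    }
}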
@@ -22,6 +22,8 @@
#define MAX_THREAD_ITERATIONS 8  // Maximum 8K
#define MAX_WARP_NUM 32
#define MAX_REGISTERS 256
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
......
@@ -133,7 +133,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
......
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N, __global__ void dropout_kernel(const int N,
const float ratio, const float ratio,
float* out, float* out,
...@@ -13,17 +15,17 @@ __global__ void dropout_kernel(const int N, ...@@ -13,17 +15,17 @@ __global__ void dropout_kernel(const int N,
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state); curand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / 4) CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{ {
float4 rand = curand_uniform4(&state); float4 rand = curand_uniform4(&state);
uint8_t m[4]; uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio); m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio); m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio); m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio); m[3] = (uint8_t)(rand.w > ratio);
int i = j * 4; int i = j * unroll_factor;
mask[i] = (uint8_t)m[0]; mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1]; mask[i + 1] = (uint8_t)m[1];
...@@ -35,6 +37,18 @@ __global__ void dropout_kernel(const int N, ...@@ -35,6 +37,18 @@ __global__ void dropout_kernel(const int N,
out[i + 2] = Xdata[i + 2] * scale * m[2]; out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3]; out[i + 3] = Xdata[i + 3] * scale * m[3];
} }
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = curand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
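// Side note (sketch, not part of the commit): the high_index expression above can be read as
// "round the N / unroll_factor vector groups up to whole blockDim.x passes, convert back to
// scalar elements, then offset by threadIdx.x". The hypothetical helper below just restates
// that arithmetic; the name and the helper itself are assumptions.
__host__ __device__ inline int tail_start(int N, int block_dim, int tid, int unroll = 4)
{
    int groups = N / unroll;                    // float4 groups handled by the unrolled loop
    int passes = (groups - 1) / block_dim + 1;  // ceil(groups / block_dim)
    return passes * (unroll * block_dim) + tid; // same value as high_index in the kernel above
}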
__global__ void dropout_kernel(const int N,
@@ -66,7 +80,7 @@ __global__ void dropout_kernel(const int N,
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
@@ -80,7 +94,7 @@ __global__ void dropout_kernel(const int N,
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
@@ -95,16 +109,16 @@ __global__ void dropout_kernel(const int N,
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
@@ -123,6 +137,18 @@ __global__ void dropout_kernel(const int N,
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = curand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
@@ -133,15 +159,20 @@ __global__ void dropout_kernel_bwd(const int N,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
@@ -161,18 +192,20 @@ __global__ void dropout_kernel_bwd(const int N,
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
@@ -191,9 +224,9 @@ __global__ void dropout_kernel_bwd(const int N,
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
@@ -211,6 +244,13 @@ __global__ void dropout_kernel_bwd(const int N,
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
@@ -223,7 +263,9 @@ void launch_dropout(T* out,
cudaStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
@@ -264,55 +306,70 @@ __global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
DS_CUDA_NUM_THREADS,
0,
stream>>>(total_count, scale, vals, mask);
}
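// Sketch of the launch arithmetic (assumptions: DS_CUDA_NUM_THREADS is the fixed block size
// and DS_GET_BLOCKS behaves roughly like a capped ceiling division; the real definitions live
// in custom_cuda_layers.h, and the 512/4096 defaults below are illustrative values only).
// With unroll_factor == 4 each loop iteration covers four elements, which is why the grid is
// sized for total_count / unroll_factor work items rather than total_count.
inline dim3 sketch_get_blocks(int work_items, int block_size = 512, int max_blocks = 4096)
{
    int blocks = (work_items + block_size - 1) / block_size;  // ceiling divide
    if (blocks > max_blocks) blocks = max_blocks;             // grid-stride loop picks up the rest
    if (blocks < 1) blocks = 1;
    return dim3(blocks);
}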
template void launch_dropout_grad(float* vals,
@@ -341,11 +398,38 @@ __global__ void dropout_grad_kernel(const int N,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
@@ -357,9 +441,13 @@ void launch_dropout_grad(T* vals_out,
float ratio,
cudaStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
DS_CUDA_NUM_THREADS,
0,
stream>>>(total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
@@ -374,7 +462,8 @@ template void launch_dropout_grad(__half*,
float ratio,
cudaStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
@@ -383,26 +472,27 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = curand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[tid];
x_data.x += b_data.x;
@@ -415,16 +505,26 @@ __global__ void dropout_kernel(const int dim,
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = curand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[threadIdx.x % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
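// Sketch (illustrative helper, not from the commit): the kernel above now emits the four
// dropout flags of a float4 group through a single aligned 32-bit store (mask_32[j] = m_32)
// instead of four byte stores mask[i..i+3]. This assumes, as the kernel does, that the mask
// buffer is at least 4-byte aligned.
__device__ inline void store_mask4(uint8_t* mask, int group, float4 rand, float ratio)
{
    uint32_t packed;
    uint8_t* m = reinterpret_cast<uint8_t*>(&packed);
    m[0] = (uint8_t)(rand.x > ratio);
    m[1] = (uint8_t)(rand.y > ratio);
    m[2] = (uint8_t)(rand.z > ratio);
    m[3] = (uint8_t)(rand.w > ratio);
    reinterpret_cast<uint32_t*>(mask)[group] = packed;  // one coalesced 4-byte store
}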
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
@@ -433,17 +533,17 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = curand_uniform4(&state);
float2 data_f;
@@ -452,7 +552,7 @@ __global__ void dropout_kernel(const int dim,
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[tid];
float2 data_h_0 = __half22float2(data_h[0]);
@@ -466,7 +566,8 @@ __global__ void dropout_kernel(const int dim,
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
@@ -484,12 +585,21 @@ __global__ void dropout_kernel(const int dim,
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = curand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[threadIdx.x % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
@@ -502,13 +612,18 @@ void launch_dropout(T* out,
float ratio,
cudaStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, bias, out, mask, seed);
}
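// Sketch of the sizing above (commentary only; the helper and its names are hypothetical):
// the kernel walks float4-packed data, so the grid-stride loop runs over
// batch * dim / unroll_factor vector groups, and the Philox offset is advanced in proportion
// to the number of elements each thread is expected to process.
inline void sketch_bias_dropout_sizes(int batch, int dim, int grid_x, int block_x,
                                      int* total_count, uint64_t* rng_inc)
{
    const int unroll_factor = 4;
    *total_count = batch * dim / unroll_factor;             // vector groups seen by the loop
    *rng_inc = (uint64_t)(batch * dim) / grid_x / block_x;  // mirrors "inc" in launch_dropout
}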
template void launch_dropout(float*,
@@ -526,7 +641,8 @@ template void launch_dropout(__half*,
float ratio,
cudaStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
@@ -537,31 +653,34 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = curand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[tid];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
@@ -578,16 +697,29 @@ __global__ void dropout_kernel(const int dim,
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = curand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[threadIdx.x % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
@@ -598,19 +730,20 @@ __global__ void dropout_kernel(const int dim,
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = curand_uniform4(&state);
float2 data_f;
@@ -625,10 +758,9 @@ __global__ void dropout_kernel(const int dim,
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[tid];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
@@ -647,7 +779,8 @@ __global__ void dropout_kernel(const int dim,
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
@@ -670,12 +803,24 @@ __global__ void dropout_kernel(const int dim,
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = curand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[threadIdx.x % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
@@ -690,14 +835,17 @@ void launch_dropout(T* out,
float ratio,
cudaStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
......
@@ -86,40 +86,29 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
@@ -129,28 +118,31 @@ __global__ void fused_add2_kernel(__half* out,
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
@@ -162,12 +154,12 @@ void launch_fused_add2<float>(float* out,
int hidden_dim,
cudaStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
template <>
@@ -179,12 +171,12 @@ void launch_fused_add2<__half>(__half* out,
int hidden_dim,
cudaStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
......
@@ -5,22 +5,14 @@ namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic).
For specific launch constraints, see the launch functions.
*/
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
@@ -29,26 +21,37 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -70,7 +73,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
if (g.thread_rank() == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -92,15 +96,20 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
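// Sketch (standalone, illustrative) of the two-level reduction the layer-norm kernels use for
// their mean/variance sums: a shfl_down pass inside each warp, one shared-memory slot per warp,
// then a second shfl_down pass carried out by the first warp. The function name and its exact
// shape are assumptions; the real kernels inline this logic.
__device__ inline float block_sum_sketch(cg::thread_block b, float val)
{
    __shared__ float warp_partial[MAX_WARP_NUM];
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(b);
    for (int i = 1; i < 32; i *= 2) val += warp.shfl_down(val, i);
    if (warp.thread_rank() == 0) warp_partial[threadIdx.x >> 5] = val;
    b.sync();
    int warps = blockDim.x >> 5;
    if (threadIdx.x < 32) {
        val = (threadIdx.x < warps) ? warp_partial[threadIdx.x] : 0.f;
        for (int i = 1; i < 32; i *= 2) val += warp.shfl_down(val, i);
    }
    return val;  // fully reduced only in thread 0; broadcast via shared memory if all threads need it
}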
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
@@ -109,10 +118,12 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -121,20 +132,29 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -154,8 +174,10 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -175,7 +197,6 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
@@ -184,13 +205,19 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
@@ -223,33 +250,21 @@ void launch_bias_residual_layer_norm<float>(float* vals,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
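// Sketch (hypothetical helper mirroring the scaling just above): the block size is doubled as
// hidden_dim grows so that the row_stride / blockDim.x iterations still fit in the fixed
// NORM_REG-sized register array each thread keeps.
inline int pick_layernorm_threads(int hidden_dim, int base_threads)
{
    int threads = base_threads;
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");
    return threads;
}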
template <>
@@ -266,35 +281,25 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
@@ -302,9 +307,11 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -313,15 +320,24 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -341,7 +357,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -363,15 +380,20 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
@@ -379,10 +401,13 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
@@ -391,20 +416,29 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
@@ -424,8 +458,10 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
@@ -445,19 +481,25 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
} }
#endif #endif
} }
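/* Illustrative sketch (not from this commit): the kernels above cover an arbitrary row length by
 * running iterations = row_stride / blockDim.x full strides per thread plus at most one extra
 * element at high_index, which also means the NORM_REG-sized register arrays must hold
 * iterations + 1 entries. The hypothetical host check below verifies that this indexing touches
 * every column of a row exactly once. */
#include <cassert>
#include <vector>

void check_strided_tail_coverage(int row_stride, int block_dim)
{
    std::vector<int> hits(row_stride, 0);
    int iterations = row_stride / block_dim;  // full strides per thread
    for (int id = 0; id < block_dim; ++id) {
        for (int i = 0; i < iterations; ++i) hits[i * block_dim + id]++;
        int high_index = iterations * block_dim + id;  // tail element, if it exists
        if (high_index < row_stride) hits[high_index]++;
    }
    for (int c = 0; c < row_stride; ++c) assert(hits[c] == 1);  // no column missed or double-counted
}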
...@@ -503,33 +545,23 @@ void launch_bias_residual_layer_norm<float>(float* vals, ...@@ -503,33 +545,23 @@ void launch_bias_residual_layer_norm<float>(float* vals,
bool training, bool training,
float* vars) float* vars)
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(batch_size); dim3 grid_dim(batch_size);
dim3 block_dim(threads);
// The kernels below have launch-configuration limits, so the supported hidden dimensions are handled explicitly here.    // The kernels below have launch-configuration limits, so the supported hidden dimensions are handled explicitly here.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>( if (hidden_dim > 16384 && hidden_dim <= 32768)
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); threads <<= 1;
else if (hidden_dim == 512) else if (hidden_dim > 32768 && hidden_dim <= 65536)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>( threads <<= 2;
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim > 65536)
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
} }
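/* Illustrative sketch (not from this commit): why the launcher above widens the block instead of
 * templating on the hidden size. Doubling the thread count for each doubling of the hidden
 * dimension keeps the per-thread iteration count, and therefore the NORM_REG register arrays,
 * bounded. The helper name is hypothetical; base_threads stands for the THREADS macro used above. */
#include <stdexcept>

static int fp32_iterations_per_thread(int base_threads, int hidden_dim)
{
    int threads = base_threads;
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    // hidden_dim / threads full strides, plus at most one tail element per thread.
    return hidden_dim / threads + (hidden_dim % threads != 0);
}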
template <> template <>
...@@ -545,32 +577,24 @@ void launch_bias_residual_layer_norm<__half>(__half* vals, ...@@ -545,32 +577,24 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
bool training, bool training,
__half* vars) __half* vars)
{ {
constexpr int threads = 128; int threads = 128;
dim3 grid_dim(batch_size); dim3 grid_dim(batch_size);
dim3 block_dim(threads);
// The kernels below have launch-configuration limits, so the supported hidden dimensions are handled explicitly here.    // The kernels below have launch-configuration limits, so the supported hidden dimensions are handled explicitly here.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>( if (hidden_dim > 8192 && hidden_dim <= 16384)
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); threads <<= 1;
else if (hidden_dim == 512) else if (hidden_dim > 16384 && hidden_dim <= 32768)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>( threads <<= 2;
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars); else if (hidden_dim > 32768 && hidden_dim <= 65536)
else if (hidden_dim == 1024) threads <<= 3;
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>( else if (hidden_dim > 65536)
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
} }
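/* Illustrative sketch (not from this commit): the __half launcher passes hidden_dim / 2 because the
 * kernel reinterprets each row as __half2, so one "element" is a pair of fp16 values and the block
 * size can stay at 128 for rows up to 8K. The helper below is hypothetical. */
#include <stdexcept>

static int half2_row_stride(int hidden_dim, int* threads_out)
{
    if (hidden_dim % 2 != 0) throw std::runtime_error("hidden_dim must be even for __half2 packing.");
    int threads = 128;  // base block size used by this launcher
    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    *threads_out = threads;
    return hidden_dim / 2;  // row stride in units of __half2
}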
/* Normalize Gamma & Betta gradients /* Normalize Gamma & Betta gradients
...@@ -713,17 +737,17 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad, ...@@ -713,17 +737,17 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
* We compute the backward pass using X_hat = (X - u) / sqrt(variance), i.e. the output of the normalization. * We compute the backward pass using X_hat = (X - u) / sqrt(variance), i.e. the output of the normalization.
*/ */
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const float* out_grad, __global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat, const float* vals_hat,
const float* gamma, const float* gamma,
const float* betta, const float* betta,
const float* vars, const float* vars,
float* inp_grad, float* inp_grad,
bool invertible) bool invertible,
int row_stride)
{ {
constexpr int iterations = row_stride / THREADS; int iteration_stride = blockDim.x;
constexpr int iteration_stride = THREADS; // row_stride / iterations; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -731,21 +755,34 @@ __global__ void LayerNormBackward2(const float* out_grad, ...@@ -731,21 +755,34 @@ __global__ void LayerNormBackward2(const float* out_grad,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num]; __shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[iterations]; out_grad += (row * row_stride);
float vals_hat_arr[iterations]; vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id]; float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = (invertible ? (vals_hat[row * row_stride + i * iteration_stride + id] - vals_hat_arr[i] =
betta[i * iteration_stride + id]) / (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg gamma_reg
: vals_hat[row * row_stride + i * iteration_stride + id]); : vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
} }
float var_reg = vars[row]; float var_reg = vars[row];
...@@ -795,21 +832,22 @@ __global__ void LayerNormBackward2(const float* out_grad, ...@@ -795,21 +832,22 @@ __global__ void LayerNormBackward2(const float* out_grad,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= row_stride; sum /= row_stride;
for (int i = 0; i < iterations; i++) iterations = row_stride / iteration_stride;
inp_grad[row * row_stride + i * iteration_stride + id] = (vals_arr[i] - sum); for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
} }
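/* Illustrative sketch (not from this commit): a single-row CPU reference of the input gradient the
 * kernel above produces from x_hat (the normalized output). This is the textbook factorization,
 * not necessarily the kernel's exact operation order; the function name is hypothetical, and the
 * variance is assumed to already include epsilon, as stored by the forward kernel. */
#include <cmath>
#include <vector>

std::vector<float> layernorm_backward_from_xhat(const std::vector<float>& dy,
                                                const std::vector<float>& x_hat,
                                                const std::vector<float>& gamma,
                                                float variance)
{
    const int H = static_cast<int>(dy.size());
    const float rstd = 1.f / std::sqrt(variance);
    float sum_dyg = 0.f, sum_dyg_xhat = 0.f;
    for (int i = 0; i < H; ++i) {
        const float g = dy[i] * gamma[i];  // out_grad * gamma, as in the kernel's first loop
        sum_dyg += g;
        sum_dyg_xhat += g * x_hat[i];
    }
    std::vector<float> dx(H);
    for (int i = 0; i < H; ++i) {
        const float g = dy[i] * gamma[i];
        dx[i] = rstd * (g - sum_dyg / H - x_hat[i] * sum_dyg_xhat / H);
    }
    return dx;
}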
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const __half* out_grad, __global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat, const __half* vals_hat,
const __half* gamma, const __half* gamma,
const __half* betta, const __half* betta,
const __half* vars, const __half* vars,
__half* inp_grad, __half* inp_grad,
bool invertible) bool invertible,
int row_stride)
{ {
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations; int iteration_stride = blockDim.x;
constexpr int iterations = row_stride / iteration_stride; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -817,30 +855,43 @@ __global__ void LayerNormBackward2(const __half* out_grad, ...@@ -817,30 +855,43 @@ __global__ void LayerNormBackward2(const __half* out_grad,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; __shared__ float partialSum[MAX_WARP_NUM];
__shared__ float partialSum[warp_num];
__half2 vals_arr[iterations]; __half2 vals_arr[NORM_REG];
float2 vals_arr_f[iterations]; float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[iterations]; __half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad); const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat); const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma); const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr); const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id]; __half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = (invertible ? (vals_hat_h[row * row_stride + i * iteration_stride + id] - vals_hat_arr[i] =
betta_h[i * iteration_stride + id]) / (invertible
gamma_reg ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
: vals_hat_h[row * row_stride + i * iteration_stride + id]); gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
} }
__half var_h = vars[row]; __half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h); __half2 var_reg = __halves2half2(var_h, var_h);
...@@ -903,12 +954,20 @@ __global__ void LayerNormBackward2(const __half* out_grad, ...@@ -903,12 +954,20 @@ __global__ void LayerNormBackward2(const __half* out_grad,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= (2 * row_stride); sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum; vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum; vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]); __half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] = temp; inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
} }
} }
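/* Illustrative sketch (not from this commit): the __half kernels above do per-pair arithmetic in
 * __half2 but widen to fp32 before accumulating the row reductions (and divide by 2 * row_stride,
 * since each __half2 element carries two values). A hypothetical device helper showing that
 * widening for a single pair: */
#include <cuda_fp16.h>

__device__ float accumulate_pair_fp32(__half2 grad, __half2 value)
{
    __half2 prod = grad * value;           // per-pair product, as in the kernel's reduction loop
    float2 prod_f = __half22float2(prod);  // widen before summing to limit fp16 round-off
    return prod_f.x + prod_f.y;            // two scalar contributions to the row sum
}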
...@@ -926,7 +985,7 @@ void launch_layerNorm_backward<float>(const float* out_grad, ...@@ -926,7 +985,7 @@ void launch_layerNorm_backward<float>(const float* out_grad,
bool invertible, bool invertible,
const float* betta) const float* betta)
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -935,28 +994,18 @@ void launch_layerNorm_backward<float>(const float* out_grad, ...@@ -935,28 +994,18 @@ void launch_layerNorm_backward<float>(const float* out_grad,
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768) if (hidden_dim > 16384 && hidden_dim <= 32768)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible); else if (hidden_dim > 32768 && hidden_dim <= 65536)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 65536)
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1024)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
} }
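/* Illustrative sketch (not from this commit): the backward launch above is split into two kernels.
 * LayerNormBackward1 reduces over the rows to produce the gamma/betta gradients (a TILE_DIM x
 * TILE_DIM block per TILE_DIM-wide slice of the hidden dimension), and LayerNormBackward2 produces
 * the input gradient with one block per row on stream[1], presumably so the two can overlap. The
 * hypothetical helper below only spells out the launch shapes. */
struct LnBackwardShapes {
    dim3 param_grad_grid, param_grad_block;  // LayerNormBackward1 (gamma/betta gradients)
    dim3 input_grad_grid, input_grad_block;  // LayerNormBackward2 (input gradient)
};

static LnBackwardShapes layernorm_backward_shapes(int batch, int hidden_dim, int threads)
{
    LnBackwardShapes s;
    s.param_grad_grid = dim3(hidden_dim / TILE_DIM);
    s.param_grad_block = dim3(TILE_DIM, TILE_DIM);
    s.input_grad_grid = dim3(batch);     // one block per row
    s.input_grad_block = dim3(threads);  // block size already scaled with hidden_dim
    return s;
}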
template <> template <>
...@@ -973,7 +1022,7 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -973,7 +1022,7 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
bool invertible, bool invertible,
const __half* betta) const __half* betta)
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -982,28 +1031,20 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -982,28 +1031,20 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768) if (hidden_dim > 8192 && hidden_dim <= 16384)
LayerNormBackward2<384><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible); else if (hidden_dim > 16384 && hidden_dim <= 32768)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2<256><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 32768 && hidden_dim <= 65536)
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible); threads <<= 3;
else if (hidden_dim == 1024) else if (hidden_dim > 65536)
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
} }
/* Backward Normalize (Input-Gradient) /* Backward Normalize (Input-Gradient)
...@@ -1012,16 +1053,16 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -1012,16 +1053,16 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
* We compute the backward pass using the input (X) together with the saved mean and variance. * We compute the backward pass using the input (X) together with the saved mean and variance.
*/ */
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const float* out_grad, __global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals, const float* X_vals,
const float* gamma, const float* gamma,
const float* vars, const float* vars,
const float* means, const float* means,
float* inp_grad) float* inp_grad,
int row_stride)
{ {
constexpr int iterations = row_stride / THREADS; int iteration_stride = blockDim.x;
constexpr int iteration_stride = THREADS; // row_stride / iterations; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -1029,25 +1070,35 @@ __global__ void LayerNormBackward2(const float* out_grad, ...@@ -1029,25 +1070,35 @@ __global__ void LayerNormBackward2(const float* out_grad,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num]; __shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[iterations]; out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id]; float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; vals_arr[i] *= gamma_reg;
} }
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row]; float var_reg = vars[row];
float mean_reg = means[row]; float mean_reg = means[row];
float sum = 0; float sum = 0;
float xu[iterations]; float xu[NORM_REG];
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[row * row_stride + i * iteration_stride + id] - mean_reg); xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i]; sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg); vals_arr[i] *= rsqrtf(var_reg);
} }
...@@ -1092,20 +1143,21 @@ __global__ void LayerNormBackward2(const float* out_grad, ...@@ -1092,20 +1143,21 @@ __global__ void LayerNormBackward2(const float* out_grad,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= row_stride; sum /= row_stride;
for (int i = 0; i < iterations; i++) iterations = row_stride / iteration_stride;
inp_grad[row * row_stride + i * iteration_stride + id] = (vals_arr[i] - sum); for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
} }
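/* Illustrative sketch (not from this commit): the same input gradient as the x_hat-based reference
 * above, reconstructed from the raw input X and the saved mean/variance, mirroring the
 * xu = (x - mean) values the kernel keeps in registers. Hypothetical single-row CPU reference. */
#include <cmath>
#include <vector>

std::vector<float> layernorm_backward_from_x(const std::vector<float>& dy,
                                             const std::vector<float>& x,
                                             const std::vector<float>& gamma,
                                             float mean,
                                             float variance)
{
    const int H = static_cast<int>(dy.size());
    const float rstd = 1.f / std::sqrt(variance);
    float sum_g = 0.f, sum_g_xu = 0.f;
    for (int i = 0; i < H; ++i) {
        const float g = dy[i] * gamma[i];
        const float xu = x[i] - mean;
        sum_g += g;
        sum_g_xu += g * xu;
    }
    std::vector<float> dx(H);
    for (int i = 0; i < H; ++i) {
        const float xu = x[i] - mean;
        dx[i] = rstd * (dy[i] * gamma[i] - sum_g / H) - xu * rstd * rstd * rstd * (sum_g_xu / H);
    }
    return dx;
}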
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2(const __half* out_grad, __global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals, const __half* X_vals,
const __half* gamma, const __half* gamma,
const __half* vars, const __half* vars,
const __half* means, const __half* means,
__half* inp_grad) __half* inp_grad,
int row_stride)
{ {
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations; int iteration_stride = blockDim.x;
constexpr int iterations = row_stride / iteration_stride; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -1113,35 +1165,44 @@ __global__ void LayerNormBackward2(const __half* out_grad, ...@@ -1113,35 +1165,44 @@ __global__ void LayerNormBackward2(const __half* out_grad,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num]; __shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[iterations]; __half2 vals_arr[NORM_REG];
float2 vals_arr_f[iterations]; float2 vals_arr_f[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad); const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals); const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma); inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id]; __half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma vals_arr[i] *= gamma_reg; // out_grad * gamma
} }
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
iterations++;
}
__half mean_h = means[row]; __half mean_h = means[row];
__half var_h = vars[row]; __half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h); __half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h); __half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[iterations]; __half2 xu[NORM_REG];
float sum = 0.f; float sum = 0.f;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_h[row * row_stride + i * iteration_stride + id] - mean_reg); xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]); __half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h); float2 result_f = __half22float2(result_h);
sum += result_f.x; sum += result_f.x;
...@@ -1198,11 +1259,18 @@ __global__ void LayerNormBackward2(const __half* out_grad, ...@@ -1198,11 +1259,18 @@ __global__ void LayerNormBackward2(const __half* out_grad,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= (2 * row_stride); sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum; vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum; vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]); __half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] = temp; inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
} }
} }
...@@ -1219,7 +1287,7 @@ void launch_layerNorm_backward<float>(const float* out_grad, ...@@ -1219,7 +1287,7 @@ void launch_layerNorm_backward<float>(const float* out_grad,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1228,28 +1296,17 @@ void launch_layerNorm_backward<float>(const float* out_grad, ...@@ -1228,28 +1296,17 @@ void launch_layerNorm_backward<float>(const float* out_grad,
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768) if (hidden_dim > 16384 && hidden_dim <= 32768)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad, X_data, gamma, vars, means, inp_grad); else if (hidden_dim > 32768 && hidden_dim <= 65536)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 65536)
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1024)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
} }
template <> template <>
...@@ -1265,7 +1322,7 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -1265,7 +1322,7 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1274,28 +1331,19 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -1274,28 +1331,19 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768) if (hidden_dim > 8192 && hidden_dim <= 16384)
LayerNormBackward2<384><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad, X_data, gamma, vars, means, inp_grad); else if (hidden_dim > 16384 && hidden_dim <= 32768)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2<256><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 32768 && hidden_dim <= 65536)
out_grad, X_data, gamma, vars, means, inp_grad); threads <<= 3;
else if (hidden_dim == 1024) else if (hidden_dim > 65536)
LayerNormBackward2<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
} }
template <typename T> template <typename T>
...@@ -1421,7 +1469,6 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, ...@@ -1421,7 +1469,6 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
} }
} }
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const float* out_grad1, __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2, const float* out_grad2,
const float* vals_hat, const float* vals_hat,
...@@ -1429,10 +1476,11 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1, ...@@ -1429,10 +1476,11 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* betta, const float* betta,
const float* vars, const float* vars,
float* inp_grad, float* inp_grad,
bool invertible) bool invertible,
int row_stride)
{ {
constexpr int iterations = row_stride / THREADS; int iteration_stride = blockDim.x;
constexpr int iteration_stride = THREADS; // row_stride / iterations; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -1440,21 +1488,35 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1, ...@@ -1440,21 +1488,35 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num]; __shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[iterations]; out_grad1 += (row * row_stride);
float vals_hat_arr[iterations]; out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id]; float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = (invertible ? (vals_hat[row * row_stride + i * iteration_stride + id] - vals_hat_arr[i] =
betta[i * iteration_stride + id]) / (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg gamma_reg
: vals_hat[row * row_stride + i * iteration_stride + id]); : vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
} }
float var_reg = vars[row]; float var_reg = vars[row];
...@@ -1503,12 +1565,14 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1, ...@@ -1503,12 +1565,14 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= row_stride; sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) for (int i = 0; i < iterations; i++)
inp_grad[row * row_stride + i * iteration_stride + id] = inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[row * row_stride + i * iteration_stride + id]; (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
} }
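/* Illustrative sketch (not from this commit): the _fused_add variants compute the same input
 * gradient and then add the gradient flowing through the residual connection (out_grad2) in the
 * final store, saving a separate element-wise kernel. Hypothetical single-row CPU equivalent: */
#include <vector>

std::vector<float> layernorm_backward_fused_add(const std::vector<float>& ln_input_grad,
                                                const std::vector<float>& residual_grad)
{
    std::vector<float> out(ln_input_grad.size());
    for (size_t i = 0; i < out.size(); ++i)
        out[i] = ln_input_grad[i] + residual_grad[i];  // matches the "+ out_grad2[...]" in the kernel
    return out;
}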
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2, const __half* out_grad2,
const __half* vals_hat, const __half* vals_hat,
...@@ -1516,10 +1580,11 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1, ...@@ -1516,10 +1580,11 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* betta, const __half* betta,
const __half* vars, const __half* vars,
__half* inp_grad, __half* inp_grad,
bool invertible) bool invertible,
int row_stride)
{ {
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations; int iteration_stride = blockDim.x;
constexpr int iterations = row_stride / iteration_stride; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -1527,13 +1592,12 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1, ...@@ -1527,13 +1592,12 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; __shared__ float partialSum[MAX_WARP_NUM];
__shared__ float partialSum[warp_num];
__half2 vals_arr[iterations]; __half2 vals_arr[NORM_REG];
float2 vals_arr_f[iterations]; float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[iterations]; __half2 vals_hat_arr[NORM_REG];
// float2 result[iterations]; // float2 result[iterations];
...@@ -1542,18 +1606,33 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1, ...@@ -1542,18 +1606,33 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2); const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat); const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma); const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr); const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id]; __half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = (invertible ? (vals_hat_h[row * row_stride + i * iteration_stride + id] - vals_hat_arr[i] =
betta_h[i * iteration_stride + id]) / (invertible
gamma_reg ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
: vals_hat_h[row * row_stride + i * iteration_stride + id]); gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
} }
__half var_h = vars[row]; __half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h); __half2 var_reg = __halves2half2(var_h, var_h);
...@@ -1615,13 +1694,20 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1, ...@@ -1615,13 +1694,20 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= (2 * row_stride); sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum; vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum; vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]); __half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] = inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
temp + out_grad_h2[row * row_stride + i * iteration_stride + id]; }
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
} }
} }
...@@ -1640,7 +1726,7 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, ...@@ -1640,7 +1726,7 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
bool invertible, bool invertible,
const float* betta) const float* betta)
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1648,28 +1734,17 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, ...@@ -1648,28 +1734,17 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768) if (hidden_dim > 16384 && hidden_dim <= 32768)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible); else if (hidden_dim > 32768 && hidden_dim <= 65536)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 65536)
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1024)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
} }
template <> template <>
...@@ -1687,7 +1762,7 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -1687,7 +1762,7 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
bool invertible, bool invertible,
const __half* betta) const __half* betta)
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1696,28 +1771,19 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -1696,28 +1771,19 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768) if (hidden_dim > 8192 && hidden_dim <= 16384)
LayerNormBackward2_fused_add<384><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible); else if (hidden_dim > 16384 && hidden_dim <= 32768)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2_fused_add<256><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 32768 && hidden_dim <= 65536)
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible); threads <<= 3;
else if (hidden_dim == 1024) else if (hidden_dim > 65536)
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
} }
/* Backward Normalize (Input-Gradient) /* Backward Normalize (Input-Gradient)
...@@ -1726,17 +1792,17 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -1726,17 +1792,17 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
* We do the backward using the input (X) * We do the backward using the input (X)
*/ */
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const float* out_grad1, __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2, const float* out_grad2,
const float* X_vals, const float* X_vals,
const float* gamma, const float* gamma,
const float* vars, const float* vars,
const float* means, const float* means,
float* inp_grad) float* inp_grad,
int row_stride)
{ {
constexpr int iterations = row_stride / THREADS; int iteration_stride = blockDim.x;
constexpr int iteration_stride = THREADS; // row_stride / iterations; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -1744,25 +1810,37 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1, ...@@ -1744,25 +1810,37 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num]; __shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[iterations]; float vals_arr[NORM_REG];
float vals_hat_arr[iterations]; float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id]; float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[row * row_stride + i * iteration_stride + id]; vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
} }
float var_reg = vars[row]; float var_reg = vars[row];
float mean_reg = means[row]; float mean_reg = means[row];
float sum = 0; float sum = 0;
float xu[iterations]; float xu[NORM_REG];
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg); xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i]; sum += vals_arr[i] * xu[i];
...@@ -1809,23 +1887,25 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1, ...@@ -1809,23 +1887,25 @@ __global__ void LayerNormBackward2_fused_add(const float* out_grad1,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= row_stride; sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) for (int i = 0; i < iterations; i++)
inp_grad[row * row_stride + i * iteration_stride + id] = inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[row * row_stride + i * iteration_stride + id]; (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
; if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
} }
template <int row_stride> // Hidden_Dim
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2, const __half* out_grad2,
const __half* X_vals, const __half* X_vals,
const __half* gamma, const __half* gamma,
const __half* vars, const __half* vars,
const __half* means, const __half* means,
__half* inp_grad) __half* inp_grad,
int row_stride)
{ {
constexpr int iteration_stride = THREADS / 2; // row_stride / iterations; int iteration_stride = blockDim.x;
constexpr int iterations = row_stride / iteration_stride; int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
...@@ -1833,35 +1913,46 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1, ...@@ -1833,35 +1913,46 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
int row = blockIdx.x; int row = blockIdx.x;
int id = threadIdx.x; int id = threadIdx.x;
int wid = id / WARP_SIZE; int wid = id / WARP_SIZE;
constexpr int warp_num = int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
(iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[warp_num]; __shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[iterations]; __half2 vals_arr[NORM_REG];
float2 vals_arr_f[iterations]; float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[iterations]; __half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1); const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2); const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals); const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma); out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll #pragma unroll
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id]; __half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[row * row_stride + i * iteration_stride + id]; vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[row * row_stride + i * iteration_stride + id]; vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
} }
__half mean_h = means[row]; __half mean_h = means[row];
__half var_h = vars[row]; __half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h); __half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h); __half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[iterations]; __half2 xu[NORM_REG];
float sum = 0.f; float sum = 0.f;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
...@@ -1922,12 +2013,18 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1, ...@@ -1922,12 +2013,18 @@ __global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
sum /= (2 * row_stride); sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum; vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum; vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]); __half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[row * row_stride + i * iteration_stride + id] = inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
temp + out_grad_h2[row * row_stride + i * iteration_stride + id]; }
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
} }
} }
...@@ -1945,7 +2042,7 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, ...@@ -1945,7 +2042,7 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1954,28 +2051,17 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, ...@@ -1954,28 +2051,17 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads);
if (hidden_dim == 768) if (hidden_dim > 16384 && hidden_dim <= 32768)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad); else if (hidden_dim > 32768 && hidden_dim <= 65536)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 65536)
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1024)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<1536><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<2048><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<2560><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupported hidden_dim.");        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
} }
template <> template <>
...@@ -1992,7 +2078,7 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -1992,7 +2078,7 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -2001,26 +2087,17 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -2001,26 +2087,17 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch); dim3 grid_dim2(batch);
dim3 block_dim2(threads / 2);
if (hidden_dim == 768) if (hidden_dim > 8192 && hidden_dim <= 16384)
LayerNormBackward2_fused_add<384><<<grid_dim2, block_dim2, 0, stream[1]>>>( threads <<= 1;
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad); else if (hidden_dim > 16384 && hidden_dim <= 32768)
else if (hidden_dim == 512) threads <<= 2;
LayerNormBackward2_fused_add<256><<<grid_dim2, block_dim2, 0, stream[1]>>>( else if (hidden_dim > 32768 && hidden_dim <= 65536)
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad); threads <<= 3;
else if (hidden_dim == 1024) else if (hidden_dim > 65536)
LayerNormBackward2_fused_add<512><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 1536)
LayerNormBackward2_fused_add<768><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2048)
LayerNormBackward2_fused_add<1024><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else if (hidden_dim == 2560)
LayerNormBackward2_fused_add<1280><<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad);
else
throw std::runtime_error("Unsupport hidden_dim."); throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
} }
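For reference, the dispatch above replaces the fixed template sizes with a block size scaled from hidden_dim. Below is a minimal host-side sketch of that selection logic for the fp32 path (the __half path uses finer thresholds starting at 8K); the value 256 only stands in for the library's THREADS constant and is an assumption, not taken from this patch.

// Host-side sketch, not the launcher itself: picks a block size from hidden_dim
// the way the updated fp32 launch_layerNorm_backward_fused_add does.
#include <cstdio>
#include <stdexcept>

static int fused_add_block_threads(int hidden_dim, int base_threads)
{
    int threads = base_threads;
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;                               // double the block for 16K-32K
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;                               // quadruple for 32K-64K
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    return threads;
}

int main()
{
    const int base = 256;                            // placeholder for THREADS (assumption)
    for (int h : {1024, 8192, 20480, 40960})
        std::printf("hidden_dim=%d -> block of %d threads\n",
                    h, fused_add_block_threads(h, base));
    return 0;
}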
...@@ -53,27 +53,33 @@ void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, ...@@ -53,27 +53,33 @@ void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols,
} }
template <typename T> template <typename T>
__global__ void transform_0213(T* output, const T* vals, int hidden_dim, int seq_length, int heads); __global__ void transform_0213(T* output,
const T* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <> template <>
__global__ void transform_0213<float>(float* output, __global__ void transform_0213<float>(float* output,
const float* vals, const float* vals,
int hidden_dim, int hidden_dim,
int seq_length, int seq_length,
int heads) int heads,
int head_ext)
{ {
int d0_stride = hidden_dim * seq_length / 4; int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim / 4; int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads / 4; int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride; int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride; int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length; int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127) int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y; // Head (0-11) int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4) int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals); const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output); float4* output_vec = reinterpret_cast<float4*>(output);
...@@ -87,22 +93,23 @@ __global__ void transform_0213<__half>(__half* output, ...@@ -87,22 +93,23 @@ __global__ void transform_0213<__half>(__half* output,
const __half* vals, const __half* vals,
int hidden_dim, int hidden_dim,
int seq_length, int seq_length,
int heads) int heads,
int head_ext)
{ {
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length / 8; int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim / 8; int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads / 8; int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride; int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride; int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length; int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127) int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y; // Head (0-11) int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4) int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1]; float4 vals_arr[1];
...@@ -123,10 +130,13 @@ void launch_transform_0213<float>(float* output, ...@@ -123,10 +130,13 @@ void launch_transform_0213<float>(float* output,
int heads, int heads,
cudaStream_t stream) cudaStream_t stream)
{ {
dim3 block_dim(hidden_dim / heads / 4, heads); hidden_dim >>= 2;
dim3 grid_dim(batch_size, seq_length); int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
transform_0213<float> transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads); <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
} }
template <> template <>
...@@ -138,10 +148,12 @@ void launch_transform_0213<__half>(__half* output, ...@@ -138,10 +148,12 @@ void launch_transform_0213<__half>(__half* output,
int heads, int heads,
cudaStream_t stream) cudaStream_t stream)
{ {
dim3 block_dim(hidden_dim / heads / 8, heads); hidden_dim >>= 3;
dim3 grid_dim(batch_size, seq_length); int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
transform_0213<__half> transform_0213<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads); <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
} }
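Independent of the float4 vectorization and the head_ext tiling, transform_0213 computes the 0213 permutation of a [batch, seq, heads, head_dim] tensor. A CPU reference sketch follows; the shapes are illustrative assumptions, not values from the patch.

// CPU reference, not the CUDA kernel: out is [batch, heads, seq, head_dim],
// in is [batch, seq, heads, head_dim].
#include <cstdio>
#include <vector>

static void transform_0213_ref(float* out, const float* in,
                               int batch, int seq, int heads, int head_dim)
{
    for (int b = 0; b < batch; b++)
        for (int s = 0; s < seq; s++)
            for (int h = 0; h < heads; h++)
                for (int d = 0; d < head_dim; d++)
                    out[((b * heads + h) * seq + s) * head_dim + d] =
                        in[((b * seq + s) * heads + h) * head_dim + d];
}

int main()
{
    const int batch = 2, seq = 4, heads = 2, head_dim = 3;   // example sizes (assumption)
    std::vector<float> in(batch * seq * heads * head_dim), out(in.size());
    for (size_t i = 0; i < in.size(); i++) in[i] = static_cast<float>(i);
    transform_0213_ref(out.data(), in.data(), batch, seq, heads, head_dim);
    std::printf("out[0..5]: %g %g %g %g %g %g\n",
                out[0], out[1], out[2], out[3], out[4], out[5]);
    return 0;
}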
// Bias add // Bias add
...@@ -151,7 +163,8 @@ __global__ void bias_add_transform_0213(T* output, ...@@ -151,7 +163,8 @@ __global__ void bias_add_transform_0213(T* output,
const T* bias, const T* bias,
int hidden_dim, int hidden_dim,
int seq_length, int seq_length,
int heads); int heads,
int head_ext);
template <> template <>
__global__ void bias_add_transform_0213<float>(float* output, __global__ void bias_add_transform_0213<float>(float* output,
...@@ -159,28 +172,29 @@ __global__ void bias_add_transform_0213<float>(float* output, ...@@ -159,28 +172,29 @@ __global__ void bias_add_transform_0213<float>(float* output,
const float* bias, const float* bias,
int hidden_dim, int hidden_dim,
int seq_length, int seq_length,
int heads) int heads,
int head_ext)
{ {
int d0_stride = hidden_dim * seq_length / 4; int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim / 4; int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads / 4; int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride; int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride; int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length; int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127) int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z; // Hidden count int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y; // Head (0-11) int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4) int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals); const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias); const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output); float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * gridDim.z + cnt * d1_stride + float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * gridDim.z + d2 * d2_stride + d3]; d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs; float4 outputs;
...@@ -202,14 +216,73 @@ __global__ void bias_add_transform_0213<__half>(__half* output, ...@@ -202,14 +216,73 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
const __half* bias, const __half* bias,
int hidden_dim, int hidden_dim,
int seq_length, int seq_length,
int heads) int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
output_vec[d3] = output_arr;
#endif
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{ {
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072]; __shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length / 8; int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim / 8; int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads / 8; int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
...@@ -237,6 +310,8 @@ __global__ void bias_add_transform_0213<__half>(__half* output, ...@@ -237,6 +310,8 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
int iter_index = cnt * d1_stride + d2 * d2_stride + d3; int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index]; bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) { for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index; int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id]; vals_arr[0] = vals_vec[input_offset + iter_id];
...@@ -255,6 +330,8 @@ __global__ void bias_add_transform_0213<__half>(__half* output, ...@@ -255,6 +330,8 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) { for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count; int iter_row = (iter * iteration_stride) + head_count;
int iter_offset = int iter_offset =
...@@ -277,10 +354,14 @@ void launch_bias_add_transform_0213<float>(float* output, ...@@ -277,10 +354,14 @@ void launch_bias_add_transform_0213<float>(float* output,
cudaStream_t stream, cudaStream_t stream,
int trans_count) int trans_count)
{ {
dim3 block_dim(hidden_dim / heads / 4, heads); hidden_dim >>= 2;
dim3 grid_dim(batch_size, seq_length, trans_count); int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
bias_add_transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads); dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
bias_add_transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} }
template <> template <>
...@@ -294,32 +375,47 @@ void launch_bias_add_transform_0213<__half>(__half* output, ...@@ -294,32 +375,47 @@ void launch_bias_add_transform_0213<__half>(__half* output,
cudaStream_t stream, cudaStream_t stream,
int trans_count) int trans_count)
{ {
dim3 block_dim(hidden_dim / heads / 8, heads, trans_count); hidden_dim >>= 3;
dim3 grid_dim(batch_size, seq_length / 2); if (hidden_dim > 128 || hidden_dim < 16) {
bias_add_transform_0213<__half> int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
<<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads); dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
bias_add_transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} else {
dim3 block_dim(hidden_dim / heads, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
bias_add_transform_0213_v2<<<grid_dim, block_dim, 0, stream>>>(
output, vals, bias, hidden_dim, seq_length, heads);
}
} }
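The __half launcher above now chooses between the head_ext kernel and the shared-memory v2 kernel based on the per-float4 hidden width. A host-side sketch of that launch-geometry selection is shown below, assuming MAX_THREADS = 1024 purely for illustration.

// Sketch of the launch-dimension choice in launch_bias_add_transform_0213<__half>;
// it only prints the geometry, it does not launch anything.
#include <cstdio>

struct Dim3 { int x, y, z; };

static void pick_bias_add_transform_0213_launch(int batch_size, int seq_length,
                                                int hidden_dim, int heads,
                                                int trans_count, int max_threads)
{
    hidden_dim >>= 3;                                // __half path packs 8 values per float4
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / max_threads + 1;
        Dim3 block{hidden_dim / heads, heads / head_ext, 1};
        Dim3 grid{batch_size, seq_length, trans_count * head_ext};
        std::printf("head_ext kernel: block(%d,%d) grid(%d,%d,%d)\n",
                    block.x, block.y, grid.x, grid.y, grid.z);
    } else {
        Dim3 block{hidden_dim / heads, heads, trans_count};
        Dim3 grid{batch_size, seq_length / 2, 1};
        std::printf("v2 kernel: block(%d,%d,%d) grid(%d,%d)\n",
                    block.x, block.y, block.z, grid.x, grid.y);
    }
}

int main()
{
    pick_bias_add_transform_0213_launch(8, 128, 1024, 16, 3, 1024);   // mid size -> v2 kernel
    pick_bias_add_transform_0213_launch(8, 128, 8192, 64, 3, 1024);   // large size -> head_ext path
    return 0;
}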
template <typename T> template <typename T>
__global__ void transform4d_0213(T* out, const T* in, int heads, int seq_length, int hidden_dim); __global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <> template <>
__global__ void transform4d_0213<float>(float* out, __global__ void transform4d_0213<float>(float* out,
const float* in, const float* in,
int heads, int heads,
int seq_length, int seq_length,
int hidden_dim) int hidden_dim,
int head_ext)
{ {
int d0_stride = hidden_dim * seq_length / 4; int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads; int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads / 4; int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride; int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride; int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim / 4; int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length + blockDim.y - 1) / blockDim.y); // Head int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z; int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8) int d3 = threadIdx.x; // Values (groups of 8)
...@@ -340,14 +436,51 @@ __global__ void transform4d_0213<__half>(__half* out, ...@@ -340,14 +436,51 @@ __global__ void transform4d_0213<__half>(__half* out,
const __half* in, const __half* in,
int heads, int heads,
int seq_length, int seq_length,
int hidden_dim) int hidden_dim,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
#endif
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{ {
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072]; __shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length / 8; int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim / 8; int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads / 8; int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head int d1 = threadIdx.y; // Head
...@@ -358,11 +491,12 @@ __global__ void transform4d_0213<__half>(__half* out, ...@@ -358,11 +491,12 @@ __global__ void transform4d_0213<__half>(__half* out,
const float4* in_vec = reinterpret_cast<const float4*>(in); const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out); float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + d1 % 2 * d2_stride; int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1); int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x); int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) { for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count; int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride; int iter_offset = (iter_row % blockDim.y) * d2_stride;
...@@ -377,6 +511,7 @@ __global__ void transform4d_0213<__half>(__half* out, ...@@ -377,6 +511,7 @@ __global__ void transform4d_0213<__half>(__half* out,
int iter_index = cnt * d1_stride + d1 * d2_stride + d3; int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) { for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index; int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id]; out_vec[output_offset + iter_id] = in_data[iter_id];
...@@ -395,10 +530,11 @@ void launch_transform4d_0213<float>(float* out, ...@@ -395,10 +530,11 @@ void launch_transform4d_0213<float>(float* out,
cudaStream_t stream, cudaStream_t stream,
int trans_count) int trans_count)
{ {
dim3 grid_dims(batch_size, heads * ((seq_length + 7) / 8), trans_count); hidden_dim >>= 2;
dim3 block_dims(hidden_dim / heads / 4, 8); dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
transform4d_0213<float> transform4d_0213<float>
<<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim); <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim, 1);
} }
template <> template <>
...@@ -411,8 +547,17 @@ void launch_transform4d_0213<__half>(__half* out, ...@@ -411,8 +547,17 @@ void launch_transform4d_0213<__half>(__half* out,
cudaStream_t stream, cudaStream_t stream,
int trans_count) int trans_count)
{ {
dim3 grid_dims(batch_size, seq_length / 2); hidden_dim >>= 3;
dim3 block_dims(hidden_dim / heads / 8, heads, trans_count); if (hidden_dim > 128 || hidden_dim < 16) {
transform4d_0213<__half> int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
<<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim); dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
transform4d_0213<__half><<<grid_dims, block_dims, 0, stream>>>(
out, in, heads, seq_length, hidden_dim, head_ext);
} else {
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads, heads, trans_count);
transform4d_0213_v2<<<grid_dims, block_dims, 0, stream>>>(
out, in, heads, seq_length, hidden_dim);
}
} }
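When head_ext > 1, the heads are tiled across extra blocks and recovered inside the kernel from blockIdx and threadIdx, as in the __half transform4d_0213 above. The small enumeration sketch below (example sizes are assumptions) checks that this index math covers every (sequence, head) pair exactly once.

// Enumerates the (blockIdx.z, threadIdx.y) -> (sequence, head) mapping used by
// the head_ext kernels and counts the distinct pairs it produces.
#include <cstdio>
#include <set>
#include <utility>

int main()
{
    const int heads = 64, seq_length = 4, head_ext = 2;   // example sizes (assumption)
    std::set<std::pair<int, int>> covered;                // (sequence, head)
    for (int block_z = 0; block_z < seq_length * head_ext; block_z++)
        for (int thread_y = 0; thread_y < heads / head_ext; thread_y++) {
            int head = thread_y + (block_z % head_ext) * (heads / head_ext);
            int seq  = block_z / head_ext;
            covered.insert({seq, head});
        }
    std::printf("covered %zu of %d (sequence, head) pairs\n",
                covered.size(), seq_length * heads);
    return 0;
}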
...@@ -261,6 +261,8 @@ def run_backward(ds_config, atol=1e-2, verbose=False): ...@@ -261,6 +261,8 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
(3,1024,120,16,24,True,True, 0.05), (3,1024,120,16,24,True,True, 0.05),
(3,1024,56,16,24,False,False, 0.1), (3,1024,56,16,24,False,False, 0.1),
(3,1024,56,16,24,False,True, 0.2), (3,1024,56,16,24,False,True, 0.2),
(3,128,56,2,24,False,False, 0.1),
(3,128,56,2,24,False,True, 0.2),
]) # yapf: disable ]) # yapf: disable
def test_backward(batch_size, def test_backward(batch_size,
hidden_size, hidden_size,
......
...@@ -213,6 +213,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): ...@@ -213,6 +213,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
# FP16 test cases can only run on the devices support FP16. # FP16 test cases can only run on the devices support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[ [
(8,256,128,4,3,True,False),
(8,256,128,4,3,True,True),
(64,1024,128,16,3,True,False), (64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True), (64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False), (8,1024,384,16,3,True,False),
...@@ -236,6 +238,10 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): ...@@ -236,6 +238,10 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
(8,2048,128,32,3,False,True), (8,2048,128,32,3,False,True),
(8,2560,128,40,3,False,False), (8,2560,128,40,3,False,False),
(8,2560,128,40,3,False,True), (8,2560,128,40,3,False,True),
(8,128,128,2,3,True,False),
(8,128,128,2,3,True,True),
(8,4096,128,64,3,True,True),
(8,8192,128,64,3,False,True),
]) # yapf: disable ]) # yapf: disable
def test_forward(batch_size, def test_forward(batch_size,
hidden_size, hidden_size,
......