OpenDAS / deepspeed · Commits

Commit eadbbe09, authored Apr 25, 2021 by 401qingkong

push rocm deepspeed v0.3.13

parent ab5534fc
Changes: 155 files in the commit; showing 20 changed files with 10002 additions and 1 deletion (+10002, -1)
deepspeed/ops/csrc/transformer/hip/cublas_wrappers.hip      +199   -0
deepspeed/ops/csrc/transformer/hip/dropout_kernels.hip      +869   -0
deepspeed/ops/csrc/transformer/hip/ds_transformer_hip.cpp   +1046  -0
deepspeed/ops/csrc/transformer/hip/gelu_kernels.hip         +336   -0
deepspeed/ops/csrc/transformer/hip/general_kernels.hip      +412   -0
deepspeed/ops/csrc/transformer/hip/normalize_kernels.hip    +2104  -0
deepspeed/ops/csrc/transformer/hip/softmax_kernels.hip      +592   -0
deepspeed/ops/csrc/transformer/hip/transform_kernels.hip    +576   -0
deepspeed/ops/csrc/transformer/normalize_kernels.cu         +2103  -0
deepspeed/ops/csrc/transformer/softmax_kernels.cu           +591   -0
deepspeed/ops/csrc/transformer/transform_kernels.cu         +575   -0
deepspeed/ops/csrc/utils/flatten_unflatten.cpp              +25    -0
deepspeed/ops/csrc/utils/hip/flatten_unflatten.cpp          +25    -0
deepspeed/ops/op_builder                                    +0     -1
deepspeed/ops/op_builder/__init__.py                        +24    -0
deepspeed/ops/op_builder/builder.py                         +340   -0
deepspeed/ops/op_builder/cpu_adam.py                        +71    -0
deepspeed/ops/op_builder/fused_adam.py                      +31    -0
deepspeed/ops/op_builder/fused_lamb.py                      +31    -0
deepspeed/ops/op_builder/sparse_attn.py                     +52    -0
deepspeed/ops/csrc/transformer/hip/cublas_wrappers.hip (new file, mode 0 → 100644)
#include "cublas_wrappers.h"
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
deepspeed/ops/csrc/transformer/hip/dropout_kernels.hip (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
#include "custom_cuda_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = hiprand_uniform4(&state);
uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[j % (dim / unroll_factor)];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[j % (dim / unroll_factor)];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
deepspeed/ops/csrc/transformer/hip/ds_transformer_hip.cpp (new file, mode 0 → 100644)
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
size_t get_workspace_size(int maxBatchSize,
                          int seq_len,
                          int hidden_size,
                          int intermediate_size,
                          int heads,
                          bool training,
                          bool gelu_checkpoint)
{
    size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template
<
typename
T
>
BertTransformerLayer
<
T
>::
BertTransformerLayer
(
int
layer_id
,
int
batch_size
,
int
hidden_size
,
int
num_heads
,
int
intermediate_size
,
int
seq_length
,
float
attn_prob_dropout_ratio
,
float
hidden_output_dropout_ratio
,
float
layer_norm_eps
,
bool
pre_or_postLayerNorm
,
const
std
::
vector
<
std
::
array
<
int
,
3
>>&
gemm_algos
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
:
_layer_id
(
layer_id
),
_batch_size
(
batch_size
),
_hidden_size
(
hidden_size
),
_heads
(
num_heads
),
_intermediate_size
(
intermediate_size
),
_seq_length
(
seq_length
),
_training
(
true
),
_pre_or_postLayerNorm
(
pre_or_postLayerNorm
),
_attn_dropout_checkpoint
(
attn_dropout_checkpoint
),
_normalize_invertible
(
normalize_invertible
),
_gelu_checkpoint
(
gelu_checkpoint
),
_stochastic_mode
(
stochastic_mode
),
_stream
(
Context
::
Instance
().
GetCurrentStream
()),
_cublasHandle
(
Context
::
Instance
().
GetCublasHandle
()),
_qkv_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
3
*
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_out_linear
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
hidden_size
,
gemm_algos
[
0
])),
_attn_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_layer_norm
(
typename
Normalize_Layer
<
T
>::
Config
(
batch_size
,
seq_length
,
hidden_size
,
layer_norm_eps
,
true
,
!
normalize_invertible
)),
_ff1
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
_intermediate_size
,
hidden_size
,
gemm_algos
[
1
])),
_ff2
(
typename
FeedForward
<
T
>::
Config
(
batch_size
*
seq_length
,
hidden_size
,
_intermediate_size
,
gemm_algos
[
2
])),
_softmax
(
typename
Softmax
<
T
>::
Config
(
batch_size
,
num_heads
,
seq_length
)),
_gelu
(
typename
Gelu
<
T
>::
Config
(
_intermediate_size
)),
_attn_prob_dropout
(
typename
Dropout
<
T
>::
Config
(
attn_prob_dropout_ratio
,
_seq_length
)),
_attn_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_layer_output_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
,
_hidden_size
)),
_attn_scores
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
,
(
T
(
1.0
)
/
T
(
sqrt
(
_hidden_size
/
_heads
))),
T
(
0.0
),
rocblas_operation_transpose
,
rocblas_operation_none
,
gemm_algos
[
3
])),
_attn_context
(
typename
StridedBatchGemm
<
T
>::
Config
(
_batch_size
*
_heads
,
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
,
T
(
1.0
),
T
(
0.0
),
rocblas_operation_none
,
rocblas_operation_none
,
gemm_algos
[
4
]))
{
assert
(
_hidden_size
%
_heads
==
0
);
Initialize
();
}
template
<
typename
T
>
BertTransformerLayer
<
T
>::~
BertTransformerLayer
()
{
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Initialize
()
{
if
(
std
::
is_same
<
T
,
__half
>::
value
)
rocblas_set_math_mode
(
_cublasHandle
,
CUBLAS_TENSOR_OP_MATH
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Forward
(
int
bsz
,
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_qkvb_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_ob_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
output_b_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
out_ptr
,
T
*
inp_norm_ptr
,
T
*
q_tf_ptr
,
T
*
k_tf_ptr
,
T
*
v_tf_ptr
,
T
*
soft_out_ptr
,
T
*
ctx_bufB_ptr
,
T
*
attn_o_inp_ptr
,
T
*
add_res_ptr
,
T
*
ff1_inp_ptr
,
T
*
gelu_inp_ptr
,
T
*
ff2_inp_ptr
)
{
rocblas_set_stream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
hipStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
;
if
(
_normalize_invertible
)
{
add_res_ptr
=
buf_1
+
3
*
small_buf_size
;
buf_2
=
add_res_ptr
;
}
if
(
_gelu_checkpoint
)
buf_2
+=
small_buf_size
;
if
(
_attn_dropout_checkpoint
)
ctx_bufB_ptr
=
(
_gelu_checkpoint
?
(
buf_2
+
(
_intermediate_size
/
_hidden_size
)
*
small_buf_size
)
:
(
buf_1
+
4
*
small_buf_size
));
int
bsz_seq
=
bsz
*
_seq_length
;
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
input_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Forward
(
bsz_seq
,
inp_norm_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
else
_qkv_linear
.
Forward
(
bsz_seq
,
input_ptr
,
attn_qkvw_ptr
,
buf_0
,
_cublasHandle
);
launch_bias_add_transform_0213
<
T
>
(
q_tf_ptr
,
buf_0
,
attn_qkvb_ptr
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
,
3
);
int
bsz_heads
=
bsz
*
_heads
;
// attention scores
_attn_scores
.
Forward
(
bsz_heads
,
soft_out_ptr
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
);
// Softmax + Mask
_softmax
.
Forward
(
bsz
,
soft_out_ptr
,
input_mask_ptr
,
_stream
);
// attn prob dropout.
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr
,
soft_out_ptr
,
_stream
);
// attention context
_attn_context
.
Forward
(
bsz_heads
,
buf_1
,
v_tf_ptr
,
ctx_bufB_ptr
,
_cublasHandle
);
launch_transform4d_0213
<
T
>
(
attn_o_inp_ptr
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
1
);
if
(
_pre_or_postLayerNorm
)
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
buf_1
,
_cublasHandle
);
else
_attn_out_linear
.
Forward
(
bsz_seq
,
attn_o_inp_ptr
,
attn_ow_ptr
,
ff1_inp_ptr
,
_cublasHandle
);
// attn output dropout.
if
(
_pre_or_postLayerNorm
)
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
buf_1
,
input_ptr
,
attn_ob_ptr
,
_stream
);
else
_attn_output_dropout
.
ForwardWithBias
(
bsz_seq
,
add_res_ptr
,
ff1_inp_ptr
,
input_ptr
,
attn_ob_ptr
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
else
_attn_layer_norm
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
add_res_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
_stream
,
true
);
}
_ff1
.
Forward
(
bsz_seq
,
ff1_inp_ptr
,
inter_w_ptr
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
_cublasHandle
);
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
_stream
);
_ff2
.
Forward
(
bsz_seq
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
out_ptr
,
_cublasHandle
);
// layer output dropout.
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
out_ptr
,
out_ptr
,
add_res_ptr
,
output_b_ptr
,
_stream
);
else
_layer_output_dropout
.
ForwardWithBias
(
bsz_seq
,
inp_norm_ptr
,
out_ptr
,
ff1_inp_ptr
,
output_b_ptr
,
_stream
);
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
ForwardCheckpoint
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
else
_layer_norm
.
Forward
(
bsz_seq
,
out_ptr
,
inp_norm_ptr
,
norm_w_ptr
,
norm_b_ptr
,
_stream
,
true
);
}
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Backward
(
int
bsz
,
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
inp_norm_ptr
,
const
T
*
q_tf_ptr
,
const
T
*
k_tf_ptr
,
const
T
*
v_tf_ptr
,
const
T
*
soft_out_ptr
,
const
T
*
ctx_bufB_ptr
,
const
T
*
attn_o_inp_ptr
,
const
T
*
add_res_ptr
,
const
T
*
ff1_inp_ptr
,
const
T
*
gelu_inp_ptr
,
const
T
*
ff2_inp_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
grad_input_ptr
,
T
*
grad_attn_qkvw_ptr
,
T
*
grad_attn_qkvb_ptr
,
T
*
grad_attn_ow_ptr
,
T
*
grad_attn_ob_ptr
,
T
*
grad_attn_nw_ptr
,
T
*
grad_attn_nb_ptr
,
T
*
grad_inter_w_ptr
,
T
*
grad_inter_b_ptr
,
T
*
grad_output_w_ptr
,
T
*
grad_output_b_ptr
,
T
*
grad_norm_w_ptr
,
T
*
grad_norm_b_ptr
)
{
rocblas_set_stream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
hipStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
+
small_buf_size
;
T
*
buf_3
=
buf_2
+
small_buf_size
;
T
*
ff2_buf
=
(
_gelu_checkpoint
?
buf_3
+
(
bsz
*
_seq_length
*
_intermediate_size
)
:
buf_3
+
small_buf_size
);
T
*
ctx_bufB_ptr_recomp
=
ff2_buf
+
(
_seq_length
*
_seq_length
*
bsz
*
_heads
);
hipStream_t
streams
[
2
]
=
{
_stream
,
_stream
};
int
bsz_seq
=
bsz
*
_seq_length
;
int
bsz_heads
=
bsz
*
_heads
;
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
inp_norm_ptr
);
else
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
output_ptr
);
}
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
grad_output_ptr
,
_stream
);
else
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
buf_1
,
_stream
);
const
T
*
layer_dropout_buf
=
_layer_output_dropout
.
HasDropout
()
?
buf_0
:
(
_pre_or_postLayerNorm
?
grad_output_ptr
:
buf_1
);
if
(
_gelu_checkpoint
)
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
ff2_inp_ptr
,
inter_b_ptr
,
buf_2
,
_stream
);
_ff2
.
Backward
(
bsz_seq
,
layer_dropout_buf
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
_cublasHandle
,
_stream
,
ff2_buf
);
_gelu
.
Backward
(
bsz_seq
,
ff2_buf
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
_stream
);
_ff1
.
Backward
(
bsz_seq
,
ff2_buf
,
ff1_inp_ptr
,
inter_w_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
_cublasHandle
,
_stream
,
buf_3
);
if
(
!
_pre_or_postLayerNorm
)
launch_fused_add2
<
T
>
(
buf_2
,
buf_3
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
_attn_output_dropout
.
Backward
(
bsz_seq
,
buf_2
,
buf_0
,
_stream
);
T
*
attn_output_dropout_buf
=
_attn_output_dropout
.
HasDropout
()
?
buf_2
:
buf_0
;
_attn_out_linear
.
Backward
(
bsz_seq
,
attn_output_dropout_buf
,
attn_o_inp_ptr
,
attn_ow_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
_cublasHandle
,
_stream
,
buf_1
);
launch_transform_0213
<
T
>
(
buf_2
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
);
if
(
_attn_prob_dropout
.
HasDropout
())
{
if
(
_attn_dropout_checkpoint
)
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr_recomp
,
soft_out_ptr
,
_stream
,
true
);
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
(
_attn_dropout_checkpoint
?
ctx_bufB_ptr_recomp
:
ctx_bufB_ptr
),
_cublasHandle
,
buf_3
,
ff2_buf
);
}
else
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
soft_out_ptr
,
_cublasHandle
,
buf_3
,
ff2_buf
);
_attn_prob_dropout
.
Backward
(
bsz_heads
*
_seq_length
,
ff2_buf
,
_stream
);
_softmax
.
Backward
(
bsz
,
ff2_buf
,
soft_out_ptr
,
_stream
);
_attn_scores
.
Backward
(
bsz_heads
,
ff2_buf
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
,
buf_2
,
buf_1
);
launch_transform4d_0213
(
ff2_buf
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
3
);
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
inp_norm_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
else
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
input_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
input_ptr
);
else
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
inp_norm_ptr
);
}
else
launch_fused_add2
<
T
>
(
grad_input_ptr
,
buf_2
,
buf_0
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetTrainingMode
(
bool
training
)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout
.
SetTrainingMode
(
training
);
_attn_output_dropout
.
SetTrainingMode
(
training
);
_layer_output_dropout
.
SetTrainingMode
(
training
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetIntermediateBuffers
(
uint8_t
*
attn_prob_dropout_mask_ptr
,
uint8_t
*
attn_output_dropout_mask_ptr
,
uint8_t
*
layer_output_dropout_mask_ptr
,
T
*
attn_layer_norm_var
,
T
*
attn_layer_norm_mean
,
T
*
layer_norm_var
,
T
*
layer_norm_mean
)
{
_attn_prob_dropout
.
SetMask
(
attn_prob_dropout_mask_ptr
);
_attn_output_dropout
.
SetMask
(
attn_output_dropout_mask_ptr
);
_layer_output_dropout
.
SetMask
(
layer_output_dropout_mask_ptr
);
_attn_layer_norm
.
SetVar
(
attn_layer_norm_var
);
_attn_layer_norm
.
SetMean
(
attn_layer_norm_mean
);
_layer_norm
.
SetVar
(
layer_norm_var
);
_layer_norm
.
SetMean
(
layer_norm_mean
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetSeqLength
(
int
seq_len
)
{
_seq_length
=
seq_len
;
_softmax
.
SetSeqLength
(
_seq_length
);
_attn_prob_dropout
.
SetDimension
(
_seq_length
);
_attn_scores
.
SetConfig
(
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
);
_attn_context
.
SetConfig
(
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
);
}
template
<
typename
T
>
int
create_transformer_layer
(
int
layer_id
,
int
batch_size
,
int
hidden_dim
,
int
num_heads
,
int
intermediate_size
,
float
attn_dropout_ratio
,
float
hidden_dropout_ratio
,
float
layer_norm_eps
,
int
seed
,
bool
pre_or_postLayerNorm
,
bool
test_gemm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
{
Context
::
Instance
().
SetSeed
(
seed
);
Context
::
Instance
().
TestGemmFP16
(
test_gemm
,
batch_size
,
init_seq_length
,
num_heads
,
hidden_dim
/
num_heads
);
auto
layer
=
std
::
make_shared
<
BertTransformerLayer
<
T
>>
(
layer_id
,
batch_size
,
hidden_dim
,
num_heads
,
intermediate_size
,
init_seq_length
,
attn_dropout_ratio
,
hidden_dropout_ratio
,
layer_norm_eps
,
pre_or_postLayerNorm
,
Context
::
Instance
().
GetGemmAlgos
(),
attn_dropout_checkpoint
,
normalize_invertible
,
gelu_checkpoint
,
stochastic_mode
);
s_transformer_layers
[
layer_id
]
=
layer
;
std
::
string
dtype
=
(
std
::
is_same
<
T
,
__half
>::
value
)
?
"half"
:
"float"
;
std
::
cout
<<
"layer #"
<<
layer_id
<<
" is created with date type ["
<<
dtype
<<
"]."
<<
std
::
endl
;
return
0
;
}
template
<
typename
T
>
std
::
vector
<
torch
::
Tensor
>
ds_transformer_forward
(
int
layer_id
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
input_mask
,
const
torch
::
Tensor
&
attn_qkvw
,
const
torch
::
Tensor
&
attn_qkvb
,
const
torch
::
Tensor
&
attn_ow
,
const
torch
::
Tensor
&
attn_ob
,
const
torch
::
Tensor
&
attn_nw
,
const
torch
::
Tensor
&
attn_nb
,
const
torch
::
Tensor
&
inter_w
,
const
torch
::
Tensor
&
inter_b
,
const
torch
::
Tensor
&
output_w
,
const
torch
::
Tensor
&
output_b
,
const
torch
::
Tensor
&
norm_w
,
const
torch
::
Tensor
&
norm_b
,
bool
training_mode
,
bool
prelayernorm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_mask
);
CHECK_INPUT
(
attn_qkvw
);
CHECK_INPUT
(
attn_qkvb
);
CHECK_INPUT
(
attn_ow
);
CHECK_INPUT
(
attn_ob
);
CHECK_INPUT
(
attn_nw
);
CHECK_INPUT
(
attn_nb
);
CHECK_INPUT
(
inter_w
);
CHECK_INPUT
(
inter_b
);
CHECK_INPUT
(
output_w
);
CHECK_INPUT
(
output_b
);
CHECK_INPUT
(
norm_w
);
CHECK_INPUT
(
norm_b
);
int
bsz
=
input
.
size
(
0
);
const
T
*
input_ptr
=
(
const
T
*
)
input
.
data_ptr
();
const
T
*
input_mask_ptr
=
(
const
T
*
)
input_mask
.
data_ptr
();
const
T
*
attn_qkvw_ptr
=
(
const
T
*
)
attn_qkvw
.
data_ptr
();
const
T
*
attn_qkvb_ptr
=
(
const
T
*
)
attn_qkvb
.
data_ptr
();
const
T
*
attn_ow_ptr
=
(
const
T
*
)
attn_ow
.
data_ptr
();
const
T
*
attn_ob_ptr
=
(
const
T
*
)
attn_ob
.
data_ptr
();
const
T
*
attn_nw_ptr
=
(
const
T
*
)
attn_nw
.
data_ptr
();
const
T
*
attn_nb_ptr
=
(
const
T
*
)
attn_nb
.
data_ptr
();
const
T
*
inter_w_ptr
=
(
const
T
*
)
inter_w
.
data_ptr
();
const
T
*
inter_b_ptr
=
(
const
T
*
)
inter_b
.
data_ptr
();
const
T
*
output_w_ptr
=
(
const
T
*
)
output_w
.
data_ptr
();
const
T
*
output_b_ptr
=
(
const
T
*
)
output_b
.
data_ptr
();
const
T
*
norm_w_ptr
=
(
const
T
*
)
norm_w
.
data_ptr
();
const
T
*
norm_b_ptr
=
(
const
T
*
)
norm_b
.
data_ptr
();
auto
output
=
torch
::
empty_like
(
input
);
T
*
out_ptr
=
(
T
*
)
output
.
data_ptr
();
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
input
.
options
().
dtype
())
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
true
);
auto
uint8_options
=
torch
::
TensorOptions
()
.
dtype
(
torch
::
kInt8
)
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
false
);
std
::
shared_ptr
<
BertTransformerLayer
<
T
>>
layer
=
std
::
static_pointer_cast
<
BertTransformerLayer
<
T
>>
(
s_transformer_layers
[
layer_id
]);
int
seq_len
=
layer
->
GetSeqLength
();
if
(
input
.
size
(
1
)
!=
seq_len
)
{
seq_len
=
input
.
size
(
1
);
layer
->
SetSeqLength
(
seq_len
);
}
auto
workspace
=
torch
::
empty
({
get_workspace_size
<
T
>
(
bsz
,
seq_len
,
layer
->
GetHiddenSize
(),
layer
->
GetIntermediateSize
(),
layer
->
GetNumHeads
(),
layer
->
IsTrainingMode
(),
layer
->
GeluCheckpoint
())},
options
);
Context
::
Instance
().
SetWorkSpace
((
T
*
)
workspace
.
data_ptr
());
auto
inp_norm
=
((
prelayernorm
||
!
normalize_invertible
)
?
torch
::
empty_like
(
input
)
:
output
);
auto
add_res
=
(
normalize_invertible
?
inp_norm
:
torch
::
empty_like
(
input
));
auto
attn_o_inp
=
torch
::
empty_like
(
input
);
auto
qkv_tf
=
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
0
)
*
3
},
options
);
auto
attn_prob_dropout_mask
=
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
uint8_options
);
auto
attn_output_dropout_mask
=
torch
::
empty
({(
bsz
*
seq_len
),
layer
->
GetHiddenSize
()},
uint8_options
);
auto
layer_output_dropout_mask
=
torch
::
empty
({(
bsz
*
seq_len
),
layer
->
GetHiddenSize
()},
uint8_options
);
auto
attn_layer_norm_var
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
attn_layer_norm_mean
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
layer_norm_var
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
layer_norm_mean
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
T
*
inp_norm_ptr
=
(
T
*
)
inp_norm
.
data_ptr
();
T
*
add_res_ptr
=
(
T
*
)
add_res
.
data_ptr
();
T
*
q_tf_ptr
=
(
T
*
)
qkv_tf
.
data_ptr
();
T
*
k_tf_ptr
=
q_tf_ptr
+
(
bsz
*
seq_len
*
output_w
.
size
(
0
));
//(T*)k_tf.data_ptr();
T
*
v_tf_ptr
=
k_tf_ptr
+
(
bsz
*
seq_len
*
output_w
.
size
(
0
));
//(T*)v_tf.data_ptr();
T
*
attn_o_inp_ptr
=
(
T
*
)
attn_o_inp
.
data_ptr
();
torch
::
Tensor
ff2_inp
=
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
1
)},
options
);
torch
::
Tensor
gelu_inp
=
(
gelu_checkpoint
?
ff2_inp
:
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
1
)},
options
));
auto
ff1_inp
=
torch
::
empty_like
(
input
);
T
*
ff2_inp_ptr
=
(
T
*
)
ff2_inp
.
data_ptr
();
T
*
gelu_inp_ptr
=
(
T
*
)
gelu_inp
.
data_ptr
();
T
*
ff1_inp_ptr
=
(
T
*
)
ff1_inp
.
data_ptr
();
torch
::
Tensor
soft_out
=
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
options
);
torch
::
Tensor
ctx_bufB
=
(
attn_dropout_checkpoint
?
soft_out
:
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
options
));
T
*
soft_out_ptr
=
(
T
*
)
soft_out
.
data_ptr
();
T
*
ctx_bufB_ptr
=
(
T
*
)
ctx_bufB
.
data_ptr
();
layer
->
SetTrainingMode
(
training_mode
);
layer
->
SetIntermediateBuffers
((
uint8_t
*
)
attn_prob_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
attn_output_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
layer_output_dropout_mask
.
data_ptr
(),
(
T
*
)
attn_layer_norm_var
.
data_ptr
(),
(
T
*
)
attn_layer_norm_mean
.
data_ptr
(),
(
T
*
)
layer_norm_var
.
data_ptr
(),
(
T
*
)
layer_norm_mean
.
data_ptr
());
layer
->
Forward
(
bsz
,
input_ptr
,
input_mask_ptr
,
attn_qkvw_ptr
,
attn_qkvb_ptr
,
attn_ow_ptr
,
attn_ob_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
inter_w_ptr
,
inter_b_ptr
,
output_w_ptr
,
output_b_ptr
,
norm_w_ptr
,
norm_b_ptr
,
out_ptr
,
inp_norm_ptr
,
q_tf_ptr
,
k_tf_ptr
,
v_tf_ptr
,
soft_out_ptr
,
ctx_bufB_ptr
,
attn_o_inp_ptr
,
add_res_ptr
,
ff1_inp_ptr
,
gelu_inp_ptr
,
ff2_inp_ptr
);
return
{
output
,
inp_norm
,
qkv_tf
,
soft_out
,
ctx_bufB
,
attn_o_inp
,
add_res
,
ff1_inp
,
gelu_inp
,
ff2_inp
,
attn_prob_dropout_mask
,
attn_output_dropout_mask
,
layer_output_dropout_mask
,
attn_layer_norm_var
,
attn_layer_norm_mean
,
layer_norm_var
,
layer_norm_mean
};
}
template
<
typename
T
>
std
::
vector
<
torch
::
Tensor
>
ds_transformer_backward
(
int
layer_id
,
const
torch
::
Tensor
&
grad_output
,
const
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
inp_norm
,
const
torch
::
Tensor
&
qkv_tf
,
const
torch
::
Tensor
&
soft_out
,
const
torch
::
Tensor
&
ctx_bufB
,
const
torch
::
Tensor
&
attn_o_inp
,
const
torch
::
Tensor
&
add_res
,
const
torch
::
Tensor
&
ff1_inp
,
const
torch
::
Tensor
&
gelu_inp
,
const
torch
::
Tensor
&
ff2_inp
,
const
torch
::
Tensor
&
attn_prob_dropout_mask
,
const
torch
::
Tensor
&
attn_output_dropout_mask
,
const
torch
::
Tensor
&
layer_output_dropout_mask
,
const
torch
::
Tensor
&
attn_layer_norm_var
,
const
torch
::
Tensor
&
attn_layer_norm_mean
,
const
torch
::
Tensor
&
layer_norm_var
,
const
torch
::
Tensor
&
layer_norm_mean
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
input_mask
,
const
torch
::
Tensor
&
attn_qkvw
,
const
torch
::
Tensor
&
attn_qkvb
,
const
torch
::
Tensor
&
attn_ow
,
const
torch
::
Tensor
&
attn_ob
,
const
torch
::
Tensor
&
attn_nw
,
const
torch
::
Tensor
&
attn_nb
,
const
torch
::
Tensor
&
inter_w
,
const
torch
::
Tensor
&
inter_b
,
const
torch
::
Tensor
&
output_w
,
const
torch
::
Tensor
&
output_b
,
const
torch
::
Tensor
&
norm_w
,
const
torch
::
Tensor
&
norm_b
)
{
auto
g_output
=
grad_output
.
contiguous
();
CHECK_INPUT
(
g_output
);
CHECK_INPUT
(
output
);
CHECK_INPUT
(
inp_norm
);
CHECK_INPUT
(
qkv_tf
);
CHECK_INPUT
(
add_res
);
CHECK_INPUT
(
soft_out
);
CHECK_INPUT
(
ctx_bufB
);
CHECK_INPUT
(
attn_o_inp
);
CHECK_INPUT
(
ff1_inp
);
CHECK_INPUT
(
gelu_inp
);
CHECK_INPUT
(
ff2_inp
);
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_mask
);
CHECK_INPUT
(
attn_qkvw
);
CHECK_INPUT
(
attn_qkvb
);
CHECK_INPUT
(
attn_ow
);
CHECK_INPUT
(
attn_ob
);
CHECK_INPUT
(
attn_nw
);
CHECK_INPUT
(
attn_nb
);
CHECK_INPUT
(
inter_w
);
CHECK_INPUT
(
inter_b
);
CHECK_INPUT
(
output_w
);
CHECK_INPUT
(
output_b
);
CHECK_INPUT
(
norm_w
);
CHECK_INPUT
(
norm_b
);
int
bsz
=
g_output
.
size
(
0
);
std
::
shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    int seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();

    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
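A note on the pointer arithmetic above: the K and V activations are not separate tensors. They are read by offsetting into a single contiguous QKV buffer, with Q, K and V each occupying bsz * seq_len * hidden elements (output_w.size(0) is the hidden size here). A minimal stand-alone sketch of that layout arithmetic, using illustrative sizes that are not taken from the extension:

// Sketch: contiguous [Q | K | V] activation buffer and the per-section offsets
// mirrored by k_tf_ptr / v_tf_ptr above (illustrative sizes; layout arithmetic only).
#include <cstdio>
#include <vector>

int main()
{
    const long bsz = 8, seq_len = 128, hidden = 1024;  // hypothetical shapes
    const long section = bsz * seq_len * hidden;       // elements per Q/K/V block
    std::vector<float> qkv(3 * section, 0.f);

    const float* q_ptr = qkv.data();
    const float* k_ptr = q_ptr + section;  // k_tf_ptr = q_tf_ptr + bsz * seq_len * hidden
    const float* v_ptr = k_ptr + section;  // v_tf_ptr = k_tf_ptr + bsz * seq_len * hidden

    std::printf("Q offset %ld, K offset %ld, V offset %ld (elements)\n",
                (long)(q_ptr - qkv.data()),
                (long)(k_ptr - qkv.data()),
                (long)(v_ptr - qkv.data()));
    return 0;
}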
deepspeed/ops/csrc/transformer/hip/gelu_kernels.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "custom_cuda_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, striding across the row for
`iterations` iterations. It was written with the intention of launching
256-thread threadblocks, so for BERT-large we would set ITERATIONS to 4.
This is currently done automatically as a heuristic, splitting the
intermediate size into blocks of 1024 elements.
For FP16, the values are loaded from memory as __half but converted to
FP32 for the arithmetic itself, to prevent numerical overflow in the
intermediate hyperbolic tangent, since there is no intrinsic
*/
__global__ void gelu_kernel(const float* input, float* vals, int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int intermediate_size)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
int iterations = intermediate_size / blockDim.x / 4;
int row_stride = intermediate_size / 4;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, output, intermediate_size);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream, input, output, intermediate_size);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream, d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
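The launch helpers above choose their geometry heuristically: the intermediate size is split into blocks of at most 1024 elements (the iteration count), and the thread count is the per-iteration block divided by the vector width of 4 (one float4, or two __half2). A small host-side sketch of that heuristic alongside a scalar reference for the tanh-based GELU approximation; this is a stand-alone illustration, not part of the extension:

// Sketch: scalar reference for the GELU approximation used above, plus the
// threads/iterations heuristic from launch_bias_gelu / launch_gelu.
#include <cmath>
#include <cstdio>

static float gelu_ref(float x)
{
    const float sqrt_param = 0.7978845608028654f;  // sqrt(2 / pi)
    const float mul_param = 0.044715f;
    return x * 0.5f * (1.0f + std::tanh(sqrt_param * (x + mul_param * x * x * x)));
}

int main()
{
    const int sizes[] = {1024, 4096, 8192};  // 4096 corresponds to BERT-large's FFN width
    for (int intermediate_size : sizes) {
        int iterations = (intermediate_size + 1023) / 1024;
        int threads = intermediate_size / iterations / 4;
        std::printf("intermediate_size=%d -> iterations=%d, threads=%d\n",
                    intermediate_size, iterations, threads);
    }
    std::printf("gelu_ref(1.0) = %f\n", gelu_ref(1.0f));
    return 0;
}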
deepspeed/ops/csrc/transformer/hip/general_kernels.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "general_kernels.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
hipStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
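column_sum_reduce, exposed through launch_fuse_transpose_bias_kernel, sums a rows x width matrix down each column; when the matrix holds per-token output gradients, that column sum is the bias gradient. A plain host-side reference of the reduction with tiny illustrative sizes (the kernel additionally tiles through shared memory and finishes with warp shuffles):

// Sketch: host reference for column_sum_reduce above -- out[c] = sum over r of inp[r][c].
#include <cstdio>
#include <vector>

int main()
{
    const int rows = 4, width = 3;  // e.g. tokens x hidden, kept tiny for illustration
    std::vector<float> inp(rows * width);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < width; ++c) inp[r * width + c] = float(r + 1);  // each column sums to 10

    std::vector<float> out(width, 0.f);
    for (int c = 0; c < width; ++c)
        for (int r = 0; r < rows; ++r) out[c] += inp[r * width + c];

    for (int c = 0; c < width; ++c) std::printf("bias_grad[%d] = %g\n", c, out[c]);
    return 0;
}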
deepspeed/ops/csrc/transformer/hip/normalize_kernels.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "custom_cuda_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32, in order to utilize the 2x throughput of
__half2 instructions and to avoid the conversion overhead (1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (g.thread_rank() == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
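/* Editor's sketch (not part of the original source): each block above handles one row and
 * computes, for a row x of `residual` with H elements:
 *   mean = sum(x) / H
 *   var  = sum((x - mean)^2) / H + epsilon      (epsilon is folded into the stored variance)
 *   out  = (x - mean) * rsqrt(var) * gamma + beta
 * A minimal host-side reference under those assumptions; hypothetical helper, illustrative
 * only (sqrtf comes from the math headers already included above).
 */
static inline void layer_norm_reference(const float* x,
                                        const float* gamma,
                                        const float* beta,
                                        float epsilon,
                                        int hidden,
                                        float* out)
{
    float mean = 0.f;
    for (int i = 0; i < hidden; ++i) mean += x[i];
    mean /= hidden;

    float var = 0.f;
    for (int i = 0; i < hidden; ++i) var += (x[i] - mean) * (x[i] - mean);
    var = var / hidden + epsilon;

    float inv_std = 1.f / sqrtf(var);
    for (int i = 0; i < hidden; ++i) out[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
}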
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
// There are some limitations on calling the kernels below, so the supported hidden_dim ranges are enumerated here.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
// There are some limitations on calling the kernels below, so the supported hidden_dim ranges are enumerated here.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
* Compute gradients using either X_hat or
* the normalization output (invertible).
* Combines the transpose with the gradient computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
* Compute gradients using the input to
* the normalization.
* Combines the transpose with the gradient computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input.
 * This type of backward is invertible!
 * We do the backward using X_hat = (X - u) / sqrt(variance), or the output of the normalization.
*/
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
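/* Editor's sketch (not part of the original source): per row, LayerNormBackward2 above
 * implements the standard layer-norm input gradient. With g = out_grad * gamma, x_hat the
 * normalized input, and var already containing epsilon:
 *   inp_grad = (g - mean(g) - x_hat * mean(g * x_hat)) / sqrt(var)
 * A minimal host-side reference under those assumptions; hypothetical helper, illustrative only.
 */
static inline void layer_norm_input_grad_reference(const float* out_grad,
                                                   const float* gamma,
                                                   const float* x_hat,
                                                   float var,
                                                   int hidden,
                                                   float* inp_grad)
{
    float mean_g = 0.f, mean_gx = 0.f;
    for (int i = 0; i < hidden; ++i) {
        float g = out_grad[i] * gamma[i];
        mean_g += g;
        mean_gx += g * x_hat[i];
    }
    mean_g /= hidden;
    mean_gx /= hidden;

    float inv_std = 1.f / sqrtf(var);  // var already includes epsilon in these kernels
    for (int i = 0; i < hidden; ++i)
        inp_grad[i] = (out_grad[i] * gamma[i] - mean_g - x_hat[i] * mean_gx) * inv_std;
}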
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
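/* Editor's sketch (not part of the original commit): a minimal host-side reference of the
 * input-gradient math that the X-based LayerNormBackward2 kernels above implement. With
 * g_i = out_grad_i * gamma_i and x_hat_i = (x_i - mean) * rsqrt(var), the kernels produce
 * inp_grad_i = (g_i - mean(g) - x_hat_i * mean(g * x_hat)) * rsqrt(var); the helper name and
 * std::vector interface below are illustrative assumptions, not DeepSpeed API.
 */
#include <cmath>
#include <vector>
static void layernorm_backward_input_reference(const std::vector<float>& out_grad,
                                               const std::vector<float>& x,
                                               const std::vector<float>& gamma,
                                               float mean,
                                               float var,
                                               std::vector<float>& inp_grad)
{
    const int n = static_cast<int>(x.size());
    const float rstd = 1.f / std::sqrt(var);
    float sum_g = 0.f, sum_gx = 0.f;
    for (int i = 0; i < n; i++) {
        const float g = out_grad[i] * gamma[i];
        const float xhat = (x[i] - mean) * rstd;
        sum_g += g;
        sum_gx += g * xhat;
    }
    // Subtract the two row means, then scale by rsqrt(var), exactly what the kernels do with
    // two block-wide reductions.
    for (int i = 0; i < n; i++) {
        const float g = out_grad[i] * gamma[i];
        const float xhat = (x[i] - mean) * rstd;
        inp_grad[i] = (g - sum_g / n - xhat * (sum_gx / n)) * rstd;
    }
}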
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
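/* Editor's sketch (illustrative, not part of the original commit): the if/else ladders in the
 * two launchers above appear intended to keep row_stride / blockDim.x within the NORM_REG
 * per-thread buffer; the FP16 launcher additionally halves the block size and row_stride
 * because its LayerNormBackward2 consumes __half2 pairs. A hypothetical helper expressing the
 * FP16 bracket selection:
 */
#include <stdexcept>
static int layernorm_backward2_thread_scale_fp16(int hidden_dim)
{
    // Mirrors the bracket boundaries used in launch_layerNorm_backward<__half>.
    if (hidden_dim <= 8192) return 1;
    if (hidden_dim <= 16384) return 2;
    if (hidden_dim <= 32768) return 4;
    if (hidden_dim <= 65536) return 8;
    throw std::runtime_error("Unsupported hidden_dim.");
}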
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
// float2 result[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
deepspeed/ops/csrc/transformer/hip/softmax_kernels.hip
0 → 100644
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"
namespace cg = cooperative_groups;
// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
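/* Editor's sketch (illustrative, not part of the original commit): what a single softmax row
 * computes above, written scalar on the host -- add the attention mask, subtract the row max
 * for numerical stability, exponentiate, then normalize with the same 1e-6 added to the
 * denominator. The helper name and std::vector interface are assumptions for illustration.
 */
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
static void masked_softmax_row_reference(std::vector<float>& row, const std::vector<float>& mask)
{
    float max_val = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < row.size(); i++) {
        row[i] += mask[i];
        max_val = std::max(max_val, row[i]);
    }
    float sum = 0.f;
    for (size_t i = 0; i < row.size(); i++) {
        row[i] = std::exp(row[i] - max_val);
        sum += row[i];
    }
    sum += 1e-6f;  // matches the epsilon the kernel adds before dividing
    for (size_t i = 0; i < row.size(); i++) row[i] /= sum;
}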
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#if __CUDA_ARCH__ >= 700
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, hipStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
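/* Editor's sketch (illustrative, not part of the original commit): both launch_attn_softmax
 * specializations process four values per thread (a float4 for FP32, two __half2 for FP16),
 * hence seq_length4 = sequence_length / 4, and short rows are packed several to a block. A
 * hypothetical helper reproducing the block_compute_size expression used above (math.h is
 * already included at the top of this file):
 */
static int attn_softmax_rows_per_block(int seq_length4, int threads)
{
    // Largest power of two <= threads / seq_length4; one row per block once a row fills it.
    if (seq_length4 >= threads) return 1;
    return (int)pow(2.0, floor(log2((float)(threads / seq_length4))));
}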
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> 5;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
grad_reg += ((float)val_reg[i] *
(float)soft_reg[i]); // if the multiplication were done in half precision,
// we could lose about 2% of accuracy in the computation
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
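/* Editor's sketch (illustrative, not part of the original commit): the scalar identity both
 * backward kernels implement. With y = softmax(x) and upstream gradient dy,
 * dx_i = y_i * (dy_i - sum_j dy_j * y_j); softmax_backward_kernel_v2 computes the dot product
 * with one warp per row. Helper name and std::vector interface are assumptions.
 */
#include <vector>
static void softmax_backward_reference(std::vector<float>& grad, const std::vector<float>& output)
{
    float dot = 0.f;
    for (size_t i = 0; i < grad.size(); i++) dot += grad[i] * output[i];
    for (size_t i = 0; i < grad.size(); i++) grad[i] = output[i] * (grad[i] - dot);
}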
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream)
{
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (seq_length <= 32)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 1>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 12>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 24>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
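/* Editor's sketch (illustrative, not part of the original commit): in the dispatch ladder of
 * launch_attn_softmax_backward_v2 above, each lane of the per-row warp strides through the row
 * WARP_SIZE elements at a time, so the launcher picks the smallest instantiated ITERATIONS
 * (1, 2, 4, 8, 12, 16, 24, 32, 64) with ITERATIONS * WARP_SIZE >= seq_length. A hypothetical
 * helper for the exact minimum:
 */
static int softmax_backward_min_iterations(int seq_length)
{
    // ceil(seq_length / WARP_SIZE); the launcher rounds this up to the nearest instantiation.
    return (seq_length + WARP_SIZE - 1) / WARP_SIZE;
}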
deepspeed/ops/csrc/transformer/hip/transform_kernels.hip
0 → 100644
#include "hip/hip_runtime.h"
#include "custom_cuda_layers.h"
#define rows_trans 16
#define cols_trans 16
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
__shared__ T data_block[rows_trans * (cols_trans + 1)];
int r = threadIdx.x / cols_trans;
int c = threadIdx.x % cols_trans;
int m = row_width / cols_trans;
int i = blockIdx.x / m * rows_trans + r;
int j = blockIdx.x % m * cols_trans + c;
int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);
for (int k = 0; k < rows_trans; k += row_stride)
data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];
__syncthreads();
i = blockIdx.x % m * rows_trans + r;
j = blockIdx.x / m * cols_trans + c;
for (int k = 0; k < rows_trans; k += row_stride)
out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
__half* out_mat,
int rows,
int cols,
hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<__half>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<float>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <typename T>
__global__ void transform_0213(T* output,
const T* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void transform_0213<float>(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
template <>
void launch_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<float>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<__half>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
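/* Editor's sketch (illustrative, not part of the original commit): transform_0213 permutes
 * [batch, seq, heads, head_dim] into [batch, heads, seq, head_dim] (dimension order 0,2,1,3);
 * the kernels above do this on float4/__half2-vectorized elements. A scalar host reference of
 * the same index mapping, with an assumed flat-array interface:
 */
static void transform_0213_reference(float* out,
                                     const float* in,
                                     int batch,
                                     int seq,
                                     int heads,
                                     int head_dim)
{
    for (int b = 0; b < batch; b++)
        for (int s = 0; s < seq; s++)
            for (int h = 0; h < heads; h++)
                for (int d = 0; d < head_dim; d++)
                    out[((b * heads + h) * seq + s) * head_dim + d] =
                        in[((b * seq + s) * heads + h) * head_dim + d];
}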
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
#if defined(__ACC_HALF__)
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
#else
float2 bias_arr_f[4];
float2 vals_arr_f[4];
#pragma unroll
for (int l = 0; l < 4; l++) {
bias_arr_f[l] = __half22float2(bias_half[l]);
vals_arr_f[l] = __half22float2(vals_half[l]);
vals_arr_f[l].x += bias_arr_f[l].x;
vals_arr_f[l].y += bias_arr_f[l].y;
output_half[l] = __float22half2_rn(vals_arr_f[l]);
}
#endif
output_vec[d3] = output_arr;
#endif
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
#endif
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<float>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<__half>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} else {
dim3 block_dim(hidden_dim / heads, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
hipLaunchKernelGGL(( bias_add_transform_0213_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads);
}
}
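/* Editor's sketch (illustrative, not part of the original commit): per the
 * "[B S C*H] -> C * [B A S N]" comment above, the launcher takes trans_count (C) fused
 * projections -- presumably 3 for Q, K and V, matching ATTN_H -- adds the per-projection bias,
 * and emits C tensors laid out as [batch (B), heads (A), seq (S), head_dim (N)]. A scalar host
 * reference of the indexing, with an assumed flat-array interface:
 */
static void bias_add_transform_0213_reference(float* out,
                                              const float* in,
                                              const float* bias,
                                              int batch,
                                              int seq,
                                              int heads,
                                              int head_dim,
                                              int trans_count)
{
    const int hidden = heads * head_dim;
    for (int c = 0; c < trans_count; c++)
        for (int b = 0; b < batch; b++)
            for (int s = 0; s < seq; s++)
                for (int h = 0; h < heads; h++)
                    for (int d = 0; d < head_dim; d++)
                        out[(((c * batch + b) * heads + h) * seq + s) * head_dim + d] =
                            in[((b * seq + s) * trans_count + c) * hidden + h * head_dim + d] +
                            bias[(c * heads + h) * head_dim + d];
}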
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
#endif
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
#endif
}
// 3 * [B A S N] -> [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
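/*
The __half launcher below works in float4 granularity (hidden_dim >>= 3 counts
8-__half vectors per row) and picks between two kernels: the generic
transform4d_0213 when that count falls outside [16, 128], and the shared-memory
transform4d_0213_v2 otherwise, which covers two sequence positions per block.
*/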
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim, head_ext);
} else {
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads, heads, trans_count);
hipLaunchKernelGGL(( transform4d_0213_v2), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim);
}
}
deepspeed/ops/csrc/transformer/normalize_kernels.cu
0 → 100644
View file @
eadbbe09
#include "custom_cuda_layers.h"
namespace
cg
=
cooperative_groups
;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic).
For specific launch constraints, see the launch functions.
*/
#define NORM_REG (MAX_REGISTERS / 4)
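/*
Each block normalizes one row (one token's hidden vector):
    y = gamma * (x - mean) * rsqrt(var + epsilon) + beta
The row mean and variance are reduced with warp shuffles plus a small shared
buffer; in training mode the per-row mean and the (epsilon-added) variance are
stored for reuse in the backward pass.
*/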
__global__ void fused_bias_residual_layer_norm(float* vals,
                                               const float* residual,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               float* vars,
                                               float* means,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id / WARP_SIZE;

    float vals_arr[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    residual += (row * row_stride);
    vals += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = residual[i * iteration_stride + id];
        sum += vals_arr[i];
    }
    if (high_index < row_stride) {
        vals_arr[iterations] = residual[high_index];
        sum += vals_arr[iterations];
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];

#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / row_stride;
    if (training)
        if (g.thread_rank() == 0) means[row] = mean;
    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] -= mean;
        variance += vals_arr[i] * vals_arr[i];
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
    variance = g.shfl(variance, 0);
    variance /= row_stride;
    variance += epsilon;
    if (training)
        if (g.thread_rank() == 0) vars[row] = variance;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = vals_arr[i] * rsqrtf(variance);
        vals_arr[i] = vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
        vals[i * iteration_stride + id] = vals_arr[i];
    }
    if ((high_index) < row_stride) {
        vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
        vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
        vals[high_index] = vals_arr[iterations];
    }
}

__global__ void fused_bias_residual_layer_norm(__half* vals,
                                               const __half* residual,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               __half* vars,
                                               __half* means,
                                               int row_stride)
{
#if __CUDA_ARCH__ >= 700
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;

    float2 vals_f[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);

    residual_cast += (row * row_stride);
    vals_cast += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
        sum += vals_f[i].x;
        sum += vals_f[i].y;
    }
    if ((high_index) < row_stride) {
        vals_f[iterations] = __half22float2(residual_cast[high_index]);
        sum += vals_f[iterations].x;
        sum += vals_f[iterations].y;
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride * 2);

    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_f[i].x -= mean;
        vals_f[i].y -= mean;
        variance += vals_f[i].x * vals_f[i].x;
        variance += vals_f[i].y * vals_f[i].y;
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
    variance = g.shfl(variance, 0);
    variance /= (row_stride * 2);
    variance += epsilon;

    __half2 variance_h = __float2half2_rn(variance);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);

    if (training && g.thread_rank() == 0) {
        vars[row] = __float2half(variance);
        means[row] = __float2half(mean);
    }
    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        __half2 vals_arr = __float22half2_rn(vals_f[i]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
        vals_cast[i * iteration_stride + id] = vals_arr;
    }
    if ((high_index) < row_stride) {
        __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
        vals_cast[high_index] = vals_arr;
    }
#endif
}

template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);

template <>
void launch_bias_residual_layer_norm<float>(float* vals,
                                            const float* residual,
                                            const float* gamma,
                                            const float* beta,
                                            float epsilon,
                                            int batch_size,
                                            int hidden_dim,
                                            cudaStream_t stream,
                                            bool preLayerNorm,
                                            bool training,
                                            float* vars,
                                            float* means)
{
    int threads = THREADS;

    dim3 grid_dim(batch_size);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}

template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
                                             const __half* residual,
                                             const __half* gamma,
                                             const __half* beta,
                                             float epsilon,
                                             int batch_size,
                                             int hidden_dim,
                                             cudaStream_t stream,
                                             bool preLayerNorm,
                                             bool training,
                                             __half* vars,
                                             __half* means)
{
    int threads = 128;

    dim3 grid_dim(batch_size);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
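/*
Launch geometry: one block per row (grid_dim = number of rows), THREADS threads
per block for float and 128 for the __half2 path, widened by the shifts above
for very large hidden sizes so the per-thread iteration count stays within the
NORM_REG register array.
*/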
__global__ void fused_bias_residual_layer_norm(float* vals,
                                               const float* residual,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               float* vars,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id / 32;

    float vals_arr[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    residual += (row * row_stride);
    vals += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = residual[i * iteration_stride + id];
        sum += vals_arr[i];
    }
    if ((high_index) < row_stride) {
        vals_arr[iterations] = residual[high_index];
        sum += vals_arr[iterations];
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];

#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / row_stride;
    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] -= mean;
        variance += vals_arr[i] * vals_arr[i];
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
    variance = g.shfl(variance, 0);
    variance /= row_stride;
    variance += epsilon;
    if (training)
        if (g.thread_rank() == 0) vars[row] = variance;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = vals_arr[i] * rsqrtf(variance);
        vals_arr[i] = vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
        vals[i * iteration_stride + id] = vals_arr[i];
    }
    if ((high_index) < row_stride) {
        vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
        vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
        vals[high_index] = vals_arr[iterations];
    }
}

__global__ void fused_bias_residual_layer_norm(__half* vals,
                                               const __half* residual,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               __half* vars,
                                               int row_stride)
{
#if __CUDA_ARCH__ >= 700
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;

    float2 vals_f[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);

    residual_cast += (row * row_stride);
    vals_cast += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
        sum += vals_f[i].x;
        sum += vals_f[i].y;
    }
    if ((high_index) < row_stride) {
        vals_f[iterations] = __half22float2(residual_cast[high_index]);
        sum += vals_f[iterations].x;
        sum += vals_f[iterations].y;
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride * 2);

    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_f[i].x -= mean;
        vals_f[i].y -= mean;
        variance += vals_f[i].x * vals_f[i].x;
        variance += vals_f[i].y * vals_f[i].y;
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
    variance = g.shfl(variance, 0);
    variance /= (row_stride * 2);
    variance += epsilon;

    __half2 variance_h = __float2half2_rn(variance);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);

    if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        __half2 vals_arr = __float22half2_rn(vals_f[i]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
        vals_cast[i * iteration_stride + id] = vals_arr;
    }
    if ((high_index) < row_stride) {
        __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
        vals_cast[high_index] = vals_arr;
    }
#endif
}

template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);

/*
To tune this launch the following restrictions must be met:

For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]

For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]

*/
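/*
Example: hidden_dim = 4096 on the FP16 path gives row_stride = 2048 __half2
elements; with the default 128 threads each thread covers 2048 / 128 = 16
elements per row, satisfying threads * iterations == row_stride.
*/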
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
                                            const float* residual,
                                            const float* gamma,
                                            const float* beta,
                                            float epsilon,
                                            int batch_size,
                                            int hidden_dim,
                                            cudaStream_t stream,
                                            bool preLayerNorm,
                                            bool training,
                                            float* vars)
{
    int threads = THREADS;

    dim3 grid_dim(batch_size);

    // There are some limitations on calling the kernel below; for now we just enumerate the
    // supported hidden sizes.
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}

template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
                                             const __half* residual,
                                             const __half* gamma,
                                             const __half* beta,
                                             float epsilon,
                                             int batch_size,
                                             int hidden_dim,
                                             cudaStream_t stream,
                                             bool preLayerNorm,
                                             bool training,
                                             __half* vars)
{
    int threads = 128;

    dim3 grid_dim(batch_size);

    // There are some limitations on calling the kernel below; for now we just enumerate the
    // supported hidden sizes.
    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
 * Compute gradients using either X_hat or
 * normalize input (invertible).
 * Combine transpose with gradients computation.
 */
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
                                   const T* __restrict__ vals_hat,
                                   const T* __restrict__ gamma,
                                   const T* __restrict__ betta,
                                   T* __restrict__ gamma_grad,
                                   T* __restrict__ betta_grad,
                                   int rows,
                                   int width,
                                   bool invertible)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
    float gamma_reg = (float)gamma[idx];

    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad[offset];
        float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
                                : (float)vals_hat[offset]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}

/* Normalize Gamma & Betta gradients
 * Compute gradients using the input to
 * the normalize.
 * Combine transpose with gradients computation.
 */
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
                                   const T* __restrict__ X_data,
                                   const T* __restrict__ vars,
                                   const T* __restrict__ means,
                                   T* __restrict__ gamma_grad,
                                   T* __restrict__ betta_grad,
                                   int rows,
                                   int width)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad[offset];
        float val = (float)X_data[offset];
        val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}
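/*
Both LayerNormBackward1 variants reduce per-element products over the rows of
the batch:
    betta_grad[j] = sum_r out_grad[r][j]
    gamma_grad[j] = sum_r out_grad[r][j] * x_hat[r][j]
A TILE_DIM x TILE_DIM shared tile transposes the partial sums so the final
shuffle reduction produces one gamma/betta element per threadIdx.y.
*/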
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input
 * This type of backward is invertible!
 * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
 */
__global__ void LayerNormBackward2(const float* out_grad,
                                   const float* vals_hat,
                                   const float* gamma,
                                   const float* betta,
                                   const float* vars,
                                   float* inp_grad,
                                   bool invertible,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / gamma_reg
                        : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);  // dval_hat = gamma * (x - u) * out_grad
        vals_arr[i] *= rsqrtf(var_reg);                         // dvar_inv = gamma * out_grad / sqrt(var)
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;
    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
    if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
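/*
The input gradient computed above (and in the __half variant below) is, for
y = gamma * x_hat + beta with x_hat = (x - mean) * rsqrt(var), mathematically
    dx = rsqrt(var) * (dy * gamma
                       - mean_row(dy * gamma)
                       - x_hat * mean_row(dy * gamma * x_hat))
where var here already includes the epsilon added in the forward pass. Both row
means reuse the warp-shuffle plus partialSum reduction pattern.
*/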
__global__ void LayerNormBackward2(const __half* out_grad,
                                   const __half* vals_hat,
                                   const __half* gamma,
                                   const __half* betta,
                                   const __half* vars,
                                   __half* inp_grad,
                                   bool invertible,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / gamma_reg
                        : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg : vals_hat_h[high_index]);
        iterations++;
    }
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }
    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);
        inp_grad_h[i * iteration_stride + id] = temp;
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
        inp_grad_h[high_index] = temp;
    }
}

template <>
void launch_layerNorm_backward<float>(const float* out_grad,
                                      const float* vals_hat,
                                      const float* vars,
                                      const float* gamma,
                                      float* gamma_grad,
                                      float* betta_grad,
                                      float* inp_grad,
                                      int batch,
                                      int hidden_dim,
                                      cudaStream_t stream[2],
                                      bool invertible,
                                      const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}

template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
                                       const __half* vals_hat,
                                       const __half* vars,
                                       const __half* gamma,
                                       __half* gamma_grad,
                                       __half* betta_grad,
                                       __half* inp_grad,
                                       int batch,
                                       int hidden_dim,
                                       cudaStream_t stream[2],
                                       bool invertible,
                                       const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input
 * This type of backward is not invertible!
 * We do the backward using the input (X)
 */
__global__ void LayerNormBackward2(const float* out_grad,
                                   const float* X_vals,
                                   const float* gamma,
                                   const float* vars,
                                   const float* means,
                                   float* inp_grad,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad += (row * row_stride);
    X_vals += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad[high_index];
        vals_arr[iterations] *= gamma_reg;
        iterations++;
    }

    float var_reg = vars[row];
    float mean_reg = means[row];

    float sum = 0;
    float xu[NORM_REG];
    for (int i = 0; i < iterations; i++) {
        xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
        sum += vals_arr[i] * xu[i];
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) {
        vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
    }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;
    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
    if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}

__global__ void LayerNormBackward2(const __half* out_grad,
                                   const __half* X_vals,
                                   const __half* gamma,
                                   const __half* vars,
                                   const __half* means,
                                   __half* inp_grad,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    inp_grad_h += (row * row_stride);
    out_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        iterations++;
    }
    __half mean_h = means[row];
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
    __half2 xu[NORM_REG];

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp;
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);

        inp_grad_h[high_index] = temp;
    }
}

template <>
void launch_layerNorm_backward<float>(const float* out_grad,
                                      const float* X_data,
                                      const float* vars,
                                      const float* means,
                                      const float* gamma,
                                      float* gamma_grad,
                                      float* betta_grad,
                                      float* inp_grad,
                                      int batch,
                                      int hidden_dim,
                                      cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}

template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
                                       const __half* X_data,
                                       const __half* vars,
                                       const __half* means,
                                       const __half* gamma,
                                       __half* gamma_grad,
                                       __half* betta_grad,
                                       __half* inp_grad,
                                       int batch,
                                       int hidden_dim,
                                       cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
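/*
The launchers split the backward pass across the two provided streams:
LayerNormBackward1 (gamma/betta gradients) runs on stream[0] while
LayerNormBackward2 (input gradient) runs on stream[1], so the two independent
reductions can overlap on the device.
*/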
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
                                             const T* __restrict__ out_grad2,
                                             const T* __restrict__ vals_hat,
                                             const T* __restrict__ gamma,
                                             const T* __restrict__ betta,
                                             T* __restrict__ gamma_grad,
                                             T* __restrict__ betta_grad,
                                             int rows,
                                             int width,
                                             bool invertible)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
    float gamma_reg = (float)gamma[idx];

    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
        float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
                                : (float)vals_hat[offset]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}

template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
                                             const T* __restrict__ out_grad2,
                                             const T* __restrict__ X_data,
                                             const T* __restrict__ vars,
                                             const T* __restrict__ means,
                                             T* __restrict__ gamma_grad,
                                             T* __restrict__ betta_grad,
                                             int rows,
                                             int width)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
        float val = (float)X_data[offset];
        val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}
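/*
The *_fused_add variants take a second output gradient (out_grad2, the gradient
arriving through the residual connection) and fold it in during the same pass:
Backward1 sums the two gradients before reducing, and Backward2 adds out_grad2
directly into inp_grad, avoiding a separate elementwise-add kernel.
*/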
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* vals_hat,
                                             const float* gamma,
                                             const float* betta,
                                             const float* vars,
                                             float* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / gamma_reg
                        : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;
    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}

__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* vals_hat,
                                             const __half* gamma,
                                             const __half* betta,
                                             const __half* vars,
                                             __half* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    // float2 result[iterations];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] =
            (invertible ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / gamma_reg
                        : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg : vals_hat_h[high_index]);
        iterations++;
    }
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }
    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);
        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}

template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* vals_hat,
                                                const float* vars,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2],
                                                bool invertible,
                                                const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}

template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* vals_hat,
                                                 const __half* vars,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2],
                                                 bool invertible,
                                                 const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
__global__
void
LayerNormBackward2_fused_add
(
const
float
*
out_grad1
,
const
float
*
out_grad2
,
const
float
*
X_vals
,
const
float
*
gamma
,
const
float
*
vars
,
const
float
*
means
,
float
*
inp_grad
,
int
row_stride
)
{
int
iteration_stride
=
blockDim
.
x
;
int
iterations
=
row_stride
/
iteration_stride
;
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
WARP_SIZE
>
g
=
cg
::
tiled_partition
<
WARP_SIZE
>
(
b
);
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
wid
=
id
/
WARP_SIZE
;
int
warp_num
=
(
THREADS
<
row_stride
?
THREADS
:
row_stride
)
/
WARP_SIZE
;
__shared__
float
partialSum
[
MAX_WARP_NUM
];
float
vals_arr
[
NORM_REG
];
float
vals_hat_arr
[
NORM_REG
];
out_grad1
+=
(
row
*
row_stride
);
out_grad2
+=
(
row
*
row_stride
);
X_vals
+=
(
row
*
row_stride
);
inp_grad
+=
(
row
*
row_stride
);
int
high_index
=
iterations
*
iteration_stride
+
id
;
#pragma unroll
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
float
gamma_reg
=
gamma
[
i
*
iteration_stride
+
id
];
vals_arr
[
i
]
=
out_grad1
[
i
*
iteration_stride
+
id
];
vals_arr
[
i
]
*=
gamma_reg
;
vals_hat_arr
[
i
]
=
X_vals
[
i
*
iteration_stride
+
id
];
}
if
((
high_index
)
<
row_stride
)
{
float
gamma_reg
=
gamma
[
high_index
];
vals_arr
[
iterations
]
=
out_grad1
[
high_index
];
vals_arr
[
iterations
]
*=
gamma_reg
;
vals_hat_arr
[
iterations
]
=
X_vals
[
high_index
];
iterations
++
;
}
float
var_reg
=
vars
[
row
];
float
mean_reg
=
means
[
row
];
float
sum
=
0
;
float
xu
[
NORM_REG
];
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
xu
[
i
]
=
(
vals_hat_arr
[
i
]
-
mean_reg
);
sum
+=
vals_arr
[
i
]
*
xu
[
i
];
vals_arr
[
i
]
*=
rsqrtf
(
var_reg
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
sum
+=
g
.
shfl_down
(
sum
,
i
);
}
if
(
g
.
thread_rank
()
==
0
)
partialSum
[
wid
]
=
sum
;
__syncthreads
();
if
(
g
.
thread_rank
()
<
warp_num
)
sum
=
partialSum
[
g
.
thread_rank
()];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
warp_num
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
row_stride
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
vals_arr
[
i
]
+=
(
-
sum
*
xu
[
i
]
*
rsqrtf
(
var_reg
)
/
(
var_reg
));
}
sum
=
0
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
sum
+=
vals_arr
[
i
];
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
sum
+=
g
.
shfl_down
(
sum
,
i
);
}
if
(
g
.
thread_rank
()
==
0
)
partialSum
[
wid
]
=
sum
;
__syncthreads
();
if
(
g
.
thread_rank
()
<
warp_num
)
sum
=
partialSum
[
g
.
thread_rank
()];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
warp_num
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
row_stride
;
iterations
=
row_stride
/
iteration_stride
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
inp_grad
[
i
*
iteration_stride
+
id
]
=
(
vals_arr
[
i
]
-
sum
)
+
out_grad2
[
i
*
iteration_stride
+
id
];
if
((
high_index
)
<
row_stride
)
inp_grad
[
high_index
]
=
(
vals_arr
[
iterations
]
-
sum
)
+
out_grad2
[
high_index
];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, const __half* out_grad2,
                                             const __half* X_vals, const __half* gamma,
                                             const __half* vars, const __half* means,
                                             __half* inp_grad, int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    inp_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] = vals_hat_h[high_index];
        iterations++;
    }

    __half mean_h = means[row];
    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
    __half2 xu[NORM_REG];

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        xu[i] = (vals_hat_arr[i] - mean_reg);
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);
        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, const float* out_grad2,
                                                const float* X_data, const float* vars,
                                                const float* means, const float* gamma,
                                                float* gamma_grad, float* betta_grad,
                                                float* inp_grad, int batch, int hidden_dim,
                                                cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, const __half* out_grad2,
                                                 const __half* X_data, const __half* vars,
                                                 const __half* means, const __half* gamma,
                                                 __half* gamma_grad, __half* betta_grad,
                                                 __half* inp_grad, int batch, int hidden_dim,
                                                 cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
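Note: both fused-add launchers above grow the block size by doubling the base thread count as hidden_dim grows, so each thread still covers at most NORM_REG elements per row. A minimal Python sketch of that dispatch rule follows; the numeric value of THREADS is an assumption (it is defined in the CUDA headers, not here), and the helper itself is illustrative only.

    # Illustrative only: mirrors the "threads <<= k" ladder in the launchers above.
    THREADS = 256  # assumption: value comes from the CUDA headers, not this file

    def backward2_block_threads(hidden_dim: int, fp16: bool) -> int:
        thresholds = [(8192, 16384), (16384, 32768), (32768, 65536)] if fp16 else \
                     [(16384, 32768), (32768, 65536)]
        threads = THREADS
        for shift, (lo, hi) in enumerate(thresholds, start=1):
            if lo < hidden_dim <= hi:
                threads <<= shift
        if hidden_dim > 65536:
            raise ValueError("Unsupport hidden_dim.")
        # the FP16 kernel processes __half2 pairs, hence block_dim2(threads / 2)
        return threads // 2 if fp16 else threads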
deepspeed/ops/csrc/transformer/softmax_kernels.cu
0 → 100644
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"

namespace cg = cooperative_groups;

// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
                             const float* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> 5;

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.x;
    int row = blockIdx.y;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.y * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;

    float4* val_cast = reinterpret_cast<float4*>(vals);
    const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);

    float4 data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float4 mask = attn_mask_cast[mask_offset + data_id];
            data[i] = val_cast[data_offset + data_id];

            data[i].x += mask.x;
            data[i].y += mask.y;
            data[i].z += mask.z;
            data[i].w += mask.w;

            max_val = (data[i].x > max_val ? data[i].x : max_val);
            max_val = (data[i].y > max_val ? data[i].y : max_val);
            max_val = (data[i].z > max_val ? data[i].z : max_val);
            max_val = (data[i].w > max_val ? data[i].w : max_val);
        } else {
            data[i].x = minus_infinity;
            data[i].y = minus_infinity;
            data[i].z = minus_infinity;
            data[i].w = minus_infinity;
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        data[i].x = __expf(data[i].x - max_val);
        data[i].y = __expf(data[i].y - max_val);
        data[i].z = __expf(data[i].z - max_val);
        data[i].w = __expf(data[i].w - max_val);

        sum += (data[i].x + data[i].y + data[i].z + data[i].w);
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        data[i].x /= sum;
        data[i].y /= sum;
        data[i].z /= sum;
        data[i].w /= sum;

        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
    }
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
                             const __half* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> 5;

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.x;
    int row = blockIdx.y;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.y * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;

    float2* val_cast = reinterpret_cast<float2*>(vals);
    const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);

    val_cast += data_offset;
    attn_mask_cast += mask_offset;

    float2 low_data[MAX_THREAD_ITERATIONS];
    float2 high_data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 data = val_cast[data_id];
            float2 mask = attn_mask_cast[data_id];

            __half2* data_arr = reinterpret_cast<__half2*>(&data);
            __half2* mask_arr = reinterpret_cast<__half2*>(&mask);

            low_data[i] = __half22float2(data_arr[0]);
            high_data[i] = __half22float2(data_arr[1]);
            float2 low_mask = __half22float2(mask_arr[0]);
            float2 high_mask = __half22float2(mask_arr[1]);

            low_data[i].x += low_mask.x;
            low_data[i].y += low_mask.y;
            high_data[i].x += high_mask.x;
            high_data[i].y += high_mask.y;

            max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
            max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
            max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
            max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            low_data[i].x = __expf(low_data[i].x - max_val);
            low_data[i].y = __expf(low_data[i].y - max_val);
            high_data[i].x = __expf(high_data[i].x - max_val);
            high_data[i].y = __expf(high_data[i].y - max_val);

            sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 result_f;
            __half2* result_h = reinterpret_cast<__half2*>(&result_f);

            low_data[i].x /= sum;
            low_data[i].y /= sum;
            high_data[i].x /= sum;
            high_data[i].y /= sum;

            result_h[0] = __float22half2_rn(low_data[i]);
            result_h[1] = __float22half2_rn(high_data[i]);

            val_cast[data_id] = result_f;
        }
    }
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);

template <>
void launch_attn_softmax<float>(float* vals,
                                const float* attn_mask,
                                int batch_size,
                                int heads,
                                int sequence_length,
                                cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    int block_compute_size =
        (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
    dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);
    int iterations =
        (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                 : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        block_compute_size =
            (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
                                   : 1);
        dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations =
            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                     : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupport Seq_Length! Check the restriction of the max_threads and "
                "max_thread_iterations!");
    }
}
template <>
void launch_attn_softmax<__half>(__half* vals,
                                 const __half* attn_mask,
                                 int batch_size,
                                 int heads,
                                 int sequence_length,
                                 cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    int block_compute_size =
        (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
    dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);
    int iterations =
        (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                 : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        block_compute_size =
            (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
                                   : 1);
        dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations =
            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
                                                     : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupport Seq_Length! Check the restriction of the max_threads and "
                "max_thread_iterations!");
    }
}
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> 5;  // warp-count = num_threads / WARP_SIZE (32)

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
                          ? (seq_length + iteration_stride - 1) / iteration_stride
                          : MAX_THREAD_ITERATIONS);

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;

    int wid = id >> 5;
    int lane = id & 0x1f;

    T val_reg[MAX_THREAD_ITERATIONS];
    T soft_reg[MAX_THREAD_ITERATIONS];
    float grad_reg = 0.0f;

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            val_reg[i] = out_grad[row * block_width + data_id];
            soft_reg[i] = soft_inp[row * block_width + data_id];

            grad_reg += ((float)val_reg[i] *
                         (float)soft_reg[i]);  // if done in half, the multiplication, we may lose
                                               // 2% of accuracy in computation!!
        }
    }
    for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = grad_reg;
        b.sync();

        if (lane < warp_num) grad_reg = partialSum[lane];

        int iters = warp_num;
        if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);

        for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

        grad_reg = g.shfl(grad_reg, id / tbSize);
    }

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
            out_grad[row * block_width + data_id] = (T)temp;
        }
    }
}
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
                                           const T* output,
                                           int softmax_length)
{
    int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
    int offset = batch_idx * softmax_length + threadIdx.x;

    grad += offset;
    output += offset;

    T grad_reg[ITERATIONS];
    T output_reg[ITERATIONS];
    float sum = 0.0;

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length) {
            grad_reg[i] = grad[i * WARP_SIZE];
            output_reg[i] = output[i * WARP_SIZE];
            sum += (float)grad_reg[i] * (float)output_reg[i];
        }
    }

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length)
            grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
    }
}
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream)
{
    const int warps_per_block = 4;
    dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
    dim3 block_dim(WARP_SIZE, warps_per_block);

    if (seq_length <= 32)
        softmax_backward_kernel_v2<T, 1>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 64)
        softmax_backward_kernel_v2<T, 2>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 128)
        softmax_backward_kernel_v2<T, 4>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 256)
        softmax_backward_kernel_v2<T, 8>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 384)
        softmax_backward_kernel_v2<T, 12>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 512)
        softmax_backward_kernel_v2<T, 16>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 768)
        softmax_backward_kernel_v2<T, 24>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 1024)
        softmax_backward_kernel_v2<T, 32>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 2048)
        softmax_backward_kernel_v2<T, 64>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else
        throw std::runtime_error(
            std::string("Special sequence length found in softmax backward, seq_length: ") +
            std::to_string(seq_length));
}

template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
                                                      const __half* soft_inp,
                                                      int batch_size,
                                                      int heads,
                                                      int seq_length,
                                                      cudaStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
                                                     const float* soft_inp,
                                                     int batch_size,
                                                     int heads,
                                                     int seq_length,
                                                     cudaStream_t stream);
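Note: the kernels in this file implement a masked, max-shifted softmax and its backward. A NumPy sketch of the reference semantics (illustrative, not the CUDA data path):

    import numpy as np

    def attn_softmax_ref(scores, mask):
        x = scores + mask                      # additive attention mask
        x = x - x.max(axis=-1, keepdims=True)  # max-shift for numerical stability
        e = np.exp(x)
        return e / (e.sum(axis=-1, keepdims=True) + 1e-6)

    def softmax_backward_ref(grad, out):
        # matches softmax_backward_kernel_v2: dx = y * (dy - sum(dy * y))
        s = (grad * out).sum(axis=-1, keepdims=True)
        return out * (grad - s)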
deepspeed/ops/csrc/transformer/transform_kernels.cu
0 → 100755
#include "custom_cuda_layers.h"

#define rows_trans 16
#define cols_trans 16

template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
    __shared__ T data_block[rows_trans * (cols_trans + 1)];

    int r = threadIdx.x / cols_trans;
    int c = threadIdx.x % cols_trans;

    int m = row_width / cols_trans;

    int i = blockIdx.x / m * rows_trans + r;
    int j = blockIdx.x % m * cols_trans + c;

    int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);

    for (int k = 0; k < rows_trans; k += row_stride)
        data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];

    __syncthreads();

    i = blockIdx.x % m * rows_trans + r;
    j = blockIdx.x / m * cols_trans + c;

    for (int k = 0; k < rows_trans; k += row_stride)
        out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}

template <>
void Transpose<__half>(const __half* inp_mat, __half* out_mat, int rows, int cols, cudaStream_t stream)
{
    int threads = THREADS;

    Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}

template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream)
{
    int threads = THREADS;

    Transpose_Kernel<float><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>(
        inp_mat, out_mat, cols, rows);
}
template <typename T>
__global__ void transform_0213(T* output, const T* vals, int hidden_dim, int seq_length,
                               int heads, int head_ext);

template <>
__global__ void transform_0213<float>(float* output, const float* vals, int hidden_dim,
                                      int seq_length, int heads, int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}

template <>
__global__ void transform_0213<__half>(__half* output, const __half* vals, int hidden_dim,
                                       int seq_length, int heads, int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    float4 vals_arr[1];

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];

#endif
}
template <>
void launch_transform_0213<float>(float* output, const float* vals, int batch_size,
                                  int seq_length, int hidden_dim, int heads, cudaStream_t stream)
{
    hidden_dim >>= 2;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));

    transform_0213<float>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_transform_0213<__half>(__half* output, const __half* vals, int batch_size,
                                   int seq_length, int hidden_dim, int heads, cudaStream_t stream)
{
    hidden_dim >>= 3;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, (seq_length * head_ext));

    transform_0213<__half>
        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
}
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output, const T* vals, const T* bias, int hidden_dim,
                                        int seq_length, int heads, int head_ext);

template <>
__global__ void bias_add_transform_0213<float>(float* output, const float* vals,
                                               const float* bias, int hidden_dim,
                                               int seq_length, int heads, int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
                             d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
    float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];

    float4 outputs;
    outputs.x = inputs.x + biases.x;
    outputs.y = inputs.y + biases.y;
    outputs.z = inputs.z + biases.z;
    outputs.w = inputs.w + biases.w;

    output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
               d2 * d2_out_stride + d3] = outputs;
}

#define ATTN_H 3
#define MAX_SEQ_LINE 10

template <>
__global__ void bias_add_transform_0213<__half>(__half* output, const __half* vals,
                                                const __half* bias, int hidden_dim,
                                                int seq_length, int heads, int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
    int cnt = blockIdx.z / head_ext;                                      // Hidden count
    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
    int d3 = threadIdx.x;                                                 // Values (groups of 4)

    float4 vals_arr;
    float4 bias_arr;
    float4 output_arr;
    __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(&output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
    vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
    vals_vec += (cnt * d1_stride);
    vals_vec += (d2 * d2_stride);

    bias_vec += (cnt * d1_stride);
    bias_vec += (d2 * d2_stride);

    output_vec += (cnt * d0_stride * gridDim.x);
    output_vec += (d1 * d2_stride);
    output_vec += (d0 * d0_stride);
    output_vec += (d2 * d2_out_stride);

    bias_arr = bias_vec[d3];
    vals_arr = vals_vec[d3];

#if defined(__ACC_HALF__)
    output_half[0] = vals_half[0] + bias_half[0];
    output_half[1] = vals_half[1] + bias_half[1];
    output_half[2] = vals_half[2] + bias_half[2];
    output_half[3] = vals_half[3] + bias_half[3];
#else
    float2 bias_arr_f[4];
    float2 vals_arr_f[4];
#pragma unroll
    for (int l = 0; l < 4; l++) {
        bias_arr_f[l] = __half22float2(bias_half[l]);
        vals_arr_f[l] = __half22float2(vals_half[l]);
        vals_arr_f[l].x += bias_arr_f[l].x;
        vals_arr_f[l].y += bias_arr_f[l].y;
        output_half[l] = __float22half2_rn(vals_arr_f[l]);
    }
#endif
    output_vec[d3] = output_arr;

#endif
}
__global__ void bias_add_transform_0213_v2(__half* output, const __half* vals, const __half* bias,
                                           int hidden_dim, int seq_length, int heads)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;
    int iteration_stride = d1_stride * blockDim.z;  // Hidden * 3 / 8
    int batch_stride = d0_stride * blockDim.z;      // Hidden * S * 3 / 8

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;    // Batch
    int d1 = blockIdx.y;    // Sequence ID (0-127)
    int cnt = threadIdx.z;  // blockIdx.z; // Hidden count
    int d2 = threadIdx.y;   // Head (0-11)
    int d3 = threadIdx.x;   // Values (groups of 4)

    float4 vals_arr[1];
    float4 bias_arr[1];
    float4 output_arr[1];
    __half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
    int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
    bias_arr[0] = bias_vec[iter_index];

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        vals_arr[0] = vals_vec[input_offset + iter_id];

        output_half[0] = vals_half[0] + bias_half[0];
        output_half[1] = vals_half[1] + bias_half[1];
        output_half[2] = vals_half[2] + bias_half[2];
        output_half[3] = vals_half[3] + bias_half[3];

        in_data[iter_id] = output_arr[0];
    }
    __syncthreads();

    iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_out_stride * gridDim.x);
    int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);

    int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = (iter * iteration_stride) + head_count;
        int iter_offset =
            (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
        output_vec[out_index + iter_offset] =
            in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
    }
#endif
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output, const float* vals, const float* bias,
                                           int batch_size, int seq_length, int hidden_dim,
                                           int heads, cudaStream_t stream, int trans_count)
{
    hidden_dim >>= 2;
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;

    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
    dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));

    bias_add_transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
        output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}

template <>
void launch_bias_add_transform_0213<__half>(__half* output, const __half* vals, const __half* bias,
                                            int batch_size, int seq_length, int hidden_dim,
                                            int heads, cudaStream_t stream, int trans_count)
{
    hidden_dim >>= 3;
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 block_dim(hidden_dim / heads, (heads / head_ext));
        dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
        bias_add_transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads, head_ext);
    } else {
        dim3 block_dim(hidden_dim / heads, heads, trans_count);
        dim3 grid_dim(batch_size, seq_length / 2);
        bias_add_transform_0213_v2<<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads);
    }
}
template <typename T>
__global__ void transform4d_0213(T* out, const T* in, int heads, int seq_length, int hidden_dim,
                                 int head_ext);

template <>
__global__ void transform4d_0213<float>(float* out, const float* in, int heads, int seq_length,
                                        int hidden_dim, int head_ext)
{
    int d0_stride = hidden_dim * seq_length;
    int d1_stride = d0_stride / heads;
    int d2_stride = hidden_dim / heads;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = hidden_dim;

    int d0 = blockIdx.x;                                        // Batch
    int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1);  // Head
    int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
    int cnt = blockIdx.z;
    int d3 = threadIdx.x;  // Values (groups of 8)

    if (d2 < seq_length) {
        const float4* in_vec = reinterpret_cast<const float4*>(in);
        float4* out_vec = reinterpret_cast<float4*>(out);

        float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
                                 d2 * d2_stride + d3];
        out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
                d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
    }
}

template <>
__global__ void transform4d_0213<__half>(__half* out, const __half* in, int heads, int seq_length,
                                         int hidden_dim, int head_ext)
{
#if __CUDA_ARCH__ >= 700

    int d0_stride = hidden_dim * (seq_length / head_ext);
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;                                                  // Batch
    int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head
    int d2 = blockIdx.z / head_ext;                                       // Sequence
    int cnt = blockIdx.y;                                                 // Hidden count
    int d3 = threadIdx.x;                                                 // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    in_vec += (cnt * d0_stride * gridDim.x);
    in_vec += (d0 * d0_stride);
    in_vec += (d2 * d2_stride);
    in_vec += (d1 * d2_stride * seq_length);

    out_vec += (cnt * d1_stride);
    out_vec += (d1 * d2_stride);
    out_vec += (d0 * d0_stride * gridDim.y);
    out_vec += (d2 * d1_stride * gridDim.y);

    out_vec[d3] = in_vec[d3];

#endif
}
__global__ void transform4d_0213_v2(__half* out, const __half* in, int heads, int seq_length,
                                    int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length;
    int d1_stride = hidden_dim;
    int d2_stride = hidden_dim / heads;

    int d0 = blockIdx.x;    // Batch
    int d1 = threadIdx.y;   // Head
    int d2 = blockIdx.y;    // Sequence
    int cnt = threadIdx.z;  // Hidden count
    int d3 = threadIdx.x;   // Values (groups of 8)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
    int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
    int iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_stride * gridDim.x);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_row = iter * iteration_stride + head_count;
        int iter_offset = (iter_row % blockDim.y) * d2_stride;

        in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
            in_vec[input_offset + iter_offset * seq_length +
                   (iter_row / blockDim.y) * matrix_stride];
    }
    __syncthreads();

    iteration_stride = d1_stride * blockDim.z;
    int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
    int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);

#pragma unroll
    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        out_vec[output_offset + iter_id] = in_data[iter_id];
    }
#endif
}
// 3 * [B A S N] - > [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out, const float* in, int batch_size, int heads,
                                    int seq_length, int hidden_dim, cudaStream_t stream,
                                    int trans_count)
{
    hidden_dim >>= 2;
    dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
    dim3 block_dims(hidden_dim / heads, 8);
    transform4d_0213<float>
        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim, 1);
}

template <>
void launch_transform4d_0213<__half>(__half* out, const __half* in, int batch_size, int heads,
                                     int seq_length, int hidden_dim, cudaStream_t stream,
                                     int trans_count)
{
    hidden_dim >>= 3;
    if (hidden_dim > 128 || hidden_dim < 16) {
        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
        dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
        dim3 block_dims(hidden_dim / heads, (heads / head_ext));
        transform4d_0213<__half><<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim, head_ext);
    } else {
        dim3 grid_dims(batch_size, seq_length / 2);
        dim3 block_dims(hidden_dim / heads, heads, trans_count);
        transform4d_0213_v2<<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim);
    }
}
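Note: at the shape level, the 0213 transforms above are head-major/sequence-major permutations of the attention tensors. A NumPy sketch of that interpretation (an assumption drawn from the stride math and the "[B S C*H] -> C * [B A S N]" comment, not code from this file):

    import numpy as np

    def transform_0213_ref(x, heads):
        # [batch, seq, hidden] -> [batch, heads, seq, hidden // heads]
        b, s, h = x.shape
        return x.reshape(b, s, heads, h // heads).transpose(0, 2, 1, 3)

    def transform4d_0213_ref(x):
        # inverse direction: [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim]
        b, a, s, n = x.shape
        return x.transpose(0, 2, 1, 3).reshape(b, s, a * n)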
deepspeed/ops/csrc/utils/flatten_unflatten.cpp
0 → 100644
/*
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
    return torch::utils::flatten_dense_tensors(tensors);
}

std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
    return torch::utils::unflatten_dense_tensors(flat, tensors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("flatten", &flatten, "Flatten dense tensors");
    m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
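Note: a hypothetical Python usage of the bindings above, assuming they are built by the UtilsBuilder op builder introduced later in this commit (names and import path are assumptions, shown only as a sketch):

    import torch
    from deepspeed.ops.op_builder import UtilsBuilder

    util_ops = UtilsBuilder().load()              # imports the prebuilt .so or JIT-compiles it
    tensors = [torch.randn(4), torch.randn(2, 3)]
    flat = util_ops.flatten(tensors)              # single contiguous 1-D buffer
    restored = util_ops.unflatten(flat, tensors)  # tensors shaped like the originals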
deepspeed/ops/csrc/utils/hip/flatten_unflatten.cpp
0 → 100644
/*
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
    return torch::utils::flatten_dense_tensors(tensors);
}

std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
    return torch::utils::unflatten_dense_tensors(flat, tensors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("flatten", &flatten, "Flatten dense tensors");
    m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
deepspeed/ops/op_builder
deleted
120000 → 0
../../op_builder
\ No newline at end of file
deepspeed/ops/op_builder/__init__.py
0 → 100644
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
from .cpu_adam import CPUAdamBuilder
from .fused_adam import FusedAdamBuilder
from .fused_lamb import FusedLambBuilder
from .sparse_attn import SparseAttnBuilder
from .transformer import TransformerBuilder
from .stochastic_transformer import StochasticTransformerBuilder
from .utils import UtilsBuilder
from .builder import get_default_compute_capatabilities

# TODO: infer this list instead of hard coded
# List of all available ops
__op_builders__ = [
    CPUAdamBuilder(),
    FusedAdamBuilder(),
    FusedLambBuilder(),
    SparseAttnBuilder(),
    TransformerBuilder(),
    StochasticTransformerBuilder(),
    UtilsBuilder()
]

ALL_OPS = {op.name: op for op in __op_builders__}
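Note: ALL_OPS maps op names to builder instances, which is how callers locate and load an op. An illustrative lookup (the import path is assumed from this file's location):

    from deepspeed.ops.op_builder import ALL_OPS

    builder = ALL_OPS['cpu_adam']
    if builder.is_compatible():
        op_module = builder.load()  # import pre-built extension or JIT-compile it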
deepspeed/ops/op_builder/builder.py
0 → 100644
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import os
import time
import torch
import importlib
from pathlib import Path
import subprocess
from abc import ABC, abstractmethod

YELLOW = '\033[93m'
END = '\033[0m'
WARNING = f"{YELLOW} [WARNING] {END}"

DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"


def installed_cuda_version():
    import torch.utils.cpp_extension
    cuda_home = torch.utils.cpp_extension.CUDA_HOME
    assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
    # Ensure there is not a cuda version mismatch between torch and nvcc compiler
    output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
    output_split = output.split()
    release_idx = output_split.index("release")
    release = output_split[release_idx + 1].replace(',', '').split(".")
    # Ignore patch versions, only look at major + minor
    cuda_major, cuda_minor = release[:2]
    installed_cuda_version = ".".join(release[:2])
    return int(cuda_major), int(cuda_minor)
def get_default_compute_capatabilities():
    compute_caps = DEFAULT_COMPUTE_CAPABILITIES
    import torch.utils.cpp_extension
    if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
        if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
            # Special treatment of CUDA 11.0 because compute_86 is not supported.
            compute_caps += ";8.0"
        else:
            compute_caps += ";8.0;8.6"
    return compute_caps
def assert_no_cuda_mismatch():
    cuda_major, cuda_minor = installed_cuda_version()
    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    # This is a show-stopping error, should probably not proceed past this
    if sys_cuda_version != torch_cuda_version:
        if sys_cuda_version == "11.1" and torch_cuda_version == "11.0":
            # it works to build against installed cuda-11.1 while torch was built with cuda-11.0
            return
        raise Exception(
            f"Installed CUDA version {sys_cuda_version} does not match the "
            f"version torch was compiled with {torch.version.cuda}, unable to compile "
            "cuda/cpp extensions without a matching cuda version.")
def assert_torch_info(torch_info):
    install_torch_version = torch_info['version']
    install_cuda_version = torch_info['cuda_version']

    current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    current_torch_version = ".".join(torch.__version__.split('.')[:2])

    if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
        raise RuntimeError(
            "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
            "with a different version than what is being used at runtime. Please re-install "
            f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
            f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
            f"torch={current_torch_version}, cuda={current_cuda_version}")
class OpBuilder(ABC):
    def __init__(self, name):
        self.name = name
        self.jit_mode = False

    @abstractmethod
    def absolute_name(self):
        '''
        Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
        will be installed as something like: deepspeed/ops/adam/cpu_adam.so
        '''
        pass

    @abstractmethod
    def sources(self):
        '''
        Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        pass

    def include_paths(self):
        '''
        Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        return []

    def nvcc_args(self):
        '''
        Returns optional list of compiler flags to forward to nvcc when building CUDA sources
        '''
        return []

    def cxx_args(self):
        '''
        Returns optional list of compiler flags to forward to the build
        '''
        return []

    def is_compatible(self):
        '''
        Check if all non-python dependencies are satisfied to build this op
        '''
        return True

    def extra_ldflags(self):
        return []

    def libraries_installed(self, libraries):
        valid = False
        check_cmd = 'dpkg -l'
        for lib in libraries:
            result = subprocess.Popen(f'dpkg -l {lib}',
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE,
                                      shell=True)
            valid = valid or result.wait() == 0
        return valid

    def simd_width(self):
        if not self.command_exists('lscpu'):
            self.warning(f"{self.name} is attempted to query 'lscpu' to detect the existence "
                         "of AVX instructions. However, 'lscpu' does not appear to exist on "
                         "your system, will fall back to non-vectorized execution.")
            return ''

        result = subprocess.check_output('lscpu', shell=True)
        result = result.decode('utf-8').strip().lower()
        if 'genuineintel' in result:
            if 'avx512' in result:
                return '-D__AVX512__'
            elif 'avx2' in result:
                return '-D__AVX256__'
        return ''

    def python_requirements(self):
        '''
        Override if op wants to define special dependencies, otherwise will
        take self.name and load requirements-<op-name>.txt if it exists.
        '''
        path = f'requirements/requirements-{self.name}.txt'
        requirements = []
        if os.path.isfile(path):
            with open(path, 'r') as fd:
                requirements = [r.strip() for r in fd.readlines()]
        return requirements

    def command_exists(self, cmd):
        if '|' in cmd:
            cmds = cmd.split("|")
        else:
            cmds = [cmd]
        valid = False
        for cmd in cmds:
            result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
            valid = valid or result.wait() == 0

        if not valid and len(cmds) > 1:
            print(
                f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!"
            )
        elif not valid and len(cmds) == 1:
            print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!")
        return valid

    def warning(self, msg):
        print(f"{WARNING} {msg}")

    def deepspeed_src_path(self, code_path):
        if os.path.isabs(code_path):
            return code_path
        else:
            return os.path.join(Path(__file__).parent.parent.absolute(), code_path)

    def builder(self):
        from torch.utils.cpp_extension import CppExtension
        return CppExtension(name=self.absolute_name(),
                            sources=self.sources(),
                            include_dirs=self.include_paths(),
                            extra_compile_args={'cxx': self.cxx_args()},
                            extra_link_args=self.extra_ldflags())

    def load(self, verbose=True):
        from ...git_version_info import installed_ops, torch_info
        if installed_ops[self.name]:
            # Ensure the op we're about to load was compiled with the same
            # torch/cuda versions we are currently using at runtime.
            if isinstance(self, CUDAOpBuilder):
                assert_torch_info(torch_info)
            return importlib.import_module(self.absolute_name())
        else:
            return self.jit_load(verbose)

    def jit_load(self, verbose=True):
        if not self.is_compatible():
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue."
            )
        try:
            import ninja
        except ImportError:
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to ninja not being installed.")

        if isinstance(self, CUDAOpBuilder):
            assert_no_cuda_mismatch()

        self.jit_mode = True
        from torch.utils.cpp_extension import load

        # Ensure directory exists to prevent race condition in some cases
        ext_path = os.path.join(
            os.environ.get('TORCH_EXTENSIONS_DIR', DEFAULT_TORCH_EXTENSION_PATH), self.name)
        os.makedirs(ext_path, exist_ok=True)

        start_build = time.time()
        op_module = load(
            name=self.name,
            sources=[self.deepspeed_src_path(path) for path in self.sources()],
            extra_include_paths=[self.deepspeed_src_path(path) for path in self.include_paths()],
            extra_cflags=self.cxx_args(),
            extra_cuda_cflags=self.nvcc_args(),
            extra_ldflags=self.extra_ldflags(),
            verbose=verbose)
        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")
        return op_module
class CUDAOpBuilder(OpBuilder):
    def compute_capability_args(self, cross_compile_archs=None):
        """
        Returns nvcc compute capability compile flags.

        1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
        2. If neither is set default compute capabilities will be used
        3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX

        Format:

        - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:

        TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
        TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...

        - `cross_compile_archs` uses ; separator.
        """
        ccs = []
        if self.jit_mode:
            # Compile for underlying architectures since we know those at runtime
            for i in range(torch.cuda.device_count()):
                CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
                cc = f"{CC_MAJOR}.{CC_MINOR}"
                if cc not in ccs:
                    ccs.append(cc)
            ccs = sorted(ccs)
            ccs[-1] += '+PTX'
        else:
            # Cross-compile mode, compile for various architectures
            # env override takes priority
            cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
            if cross_compile_archs_env is not None:
                if cross_compile_archs is not None:
                    print(
                        f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
                    )
                cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
            else:
                if cross_compile_archs is None:
                    cross_compile_archs = get_default_compute_capatabilities()
            ccs = cross_compile_archs.split(';')

        args = []
        for cc in ccs:
            num = cc[0] + cc[2]
            args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
            if cc.endswith('+PTX'):
                args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

        return args

    def version_dependent_macros(self):
        # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        version_ge_1_1 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
            version_ge_1_1 = ['-DVERSION_GE_1_1']
        version_ge_1_3 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
            version_ge_1_3 = ['-DVERSION_GE_1_3']
        version_ge_1_5 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
            version_ge_1_5 = ['-DVERSION_GE_1_5']
        return version_ge_1_1 + version_ge_1_3 + version_ge_1_5

    def is_compatible(self):
        return super().is_compatible()

    def builder(self):
        from torch.utils.cpp_extension import CUDAExtension
        assert_no_cuda_mismatch()
        return CUDAExtension(name=self.absolute_name(),
                             sources=self.sources(),
                             include_dirs=self.include_paths(),
                             extra_compile_args={
                                 'cxx': self.cxx_args(),
                                 'nvcc': self.nvcc_args()
                             })
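Note: an illustrative standalone mirror of the flag construction in compute_capability_args above, showing the shape of the -gencode flags it emits for a given arch list (the helper name is made up for the example):

    def gencode_flags(arch_list="6.0;6.1;7.0"):
        flags = []
        for cc in arch_list.split(';'):
            num = cc[0] + cc[2]  # "7.0" -> "70"
            flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
            if cc.endswith('+PTX'):
                flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
        return flags

    # gencode_flags("7.0;8.0+PTX") ->
    # ['-gencode=arch=compute_70,code=sm_70',
    #  '-gencode=arch=compute_80,code=sm_80',
    #  '-gencode=arch=compute_80,code=compute_80']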
deepspeed/ops/op_builder/cpu_adam.py
0 → 100644
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import
os
import
torch
import
subprocess
from
.builder
import
CUDAOpBuilder
class
CPUAdamBuilder
(
CUDAOpBuilder
):
BUILD_VAR
=
"DS_BUILD_CPU_ADAM"
NAME
=
"cpu_adam"
def
__init__
(
self
):
super
().
__init__
(
name
=
self
.
NAME
)
def
absolute_name
(
self
):
return
f
'deepspeed.ops.adam.
{
self
.
NAME
}
_op'
def
sources
(
self
):
return
[
'csrc/adam/cpu_adam.cpp'
,
'csrc/adam/custom_cuda_kernel.cu'
]
def
include_paths
(
self
):
CUDA_INCLUDE
=
os
.
path
.
join
(
torch
.
utils
.
cpp_extension
.
CUDA_HOME
,
"include"
)
return
[
'csrc/includes'
,
CUDA_INCLUDE
]
def
simd_width
(
self
):
if
not
self
.
command_exists
(
'lscpu'
):
self
.
warning
(
"CPUAdam attempted to query 'lscpu' to detect the existence "
"of AVX instructions. However, 'lscpu' does not appear to exist on "
"your system, will fall back to non-vectorized execution."
)
return
''
result
=
subprocess
.
check_output
(
'lscpu'
,
shell
=
True
)
result
=
result
.
decode
(
'utf-8'
).
strip
().
lower
()
if
'genuineintel'
in
result
:
if
'avx512'
in
result
:
return
'-D__AVX512__'
elif
'avx2'
in
result
:
return
'-D__AVX256__'
return
'-D__SCALAR__'
def
cxx_args
(
self
):
CUDA_LIB64
=
os
.
path
.
join
(
torch
.
utils
.
cpp_extension
.
CUDA_HOME
,
"lib64"
)
SIMD_WIDTH
=
self
.
simd_width
()
return
[
'-O3'
,
'-std=c++14'
,
f
'-L
{
CUDA_LIB64
}
'
,
'-lcudart'
,
'-lcublas'
,
'-g'
,
'-Wno-reorder'
,
'-march=native'
,
'-fopenmp'
,
SIMD_WIDTH
]
def
nvcc_args
(
self
):
args
=
[
'-O3'
,
'--use_fast_math'
,
'-std=c++14'
,
'-U__CUDA_NO_HALF_OPERATORS__'
,
'-U__CUDA_NO_HALF_CONVERSIONS__'
,
'-U__CUDA_NO_HALF2_OPERATORS__'
]
args
+=
self
.
compute_capability_args
()
return
args
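
A hedged usage sketch for the builder above; the call sequence is illustrative and assumes `CPUAdamBuilder` is exported from `deepspeed.ops.op_builder` and that `lscpu` is available on the host:

# Illustrative only: inspect the SIMD flag the build would use, then JIT-build the op.
from deepspeed.ops.op_builder import CPUAdamBuilder  # assumed export

builder = CPUAdamBuilder()
print(builder.simd_width())   # '-D__AVX512__', '-D__AVX256__' or '-D__SCALAR__'
cpu_adam_op = builder.load()  # compiles cpu_adam.cpp + custom_cuda_kernel.cu on first use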
deepspeed/ops/op_builder/fused_adam.py
0 → 100644
View file @
eadbbe09
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import
torch
from
.builder
import
CUDAOpBuilder
class
FusedAdamBuilder
(
CUDAOpBuilder
):
BUILD_VAR
=
"DS_BUILD_FUSED_ADAM"
NAME
=
"fused_adam"
def
__init__
(
self
):
super
().
__init__
(
name
=
self
.
NAME
)
def
absolute_name
(
self
):
return
f
'deepspeed.ops.adam.
{
self
.
NAME
}
_op'
def
sources
(
self
):
return
[
'csrc/adam/fused_adam_frontend.cpp'
,
'csrc/adam/multi_tensor_adam.cu'
]
def
include_paths
(
self
):
return
[
'csrc/includes'
]
def
cxx_args
(
self
):
return
[
'-O3'
]
+
self
.
version_dependent_macros
()
def
nvcc_args
(
self
):
return
[
'-lineinfo'
,
'-O3'
,
'--use_fast_math'
]
+
self
.
version_dependent_macros
()
+
self
.
compute_capability_args
()
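
Because `FusedAdamBuilder.builder()` (inherited from `CUDAOpBuilder`) returns a configured `CUDAExtension`, it can also be wired into an ahead-of-time build. The setup.py fragment below is a sketch under that assumption, not the packaging actually used by this repository:

# Illustrative setup.py fragment for prebuilding the fused Adam op.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension
from deepspeed.ops.op_builder import FusedAdamBuilder  # assumed export

setup(name='fused_adam_demo',
      ext_modules=[FusedAdamBuilder().builder()],  # CUDAExtension with sources/flags from the builder
      cmdclass={'build_ext': BuildExtension})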
deepspeed/ops/op_builder/fused_lamb.py
0 → 100644
View file @
eadbbe09
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import
torch
from
.builder
import
CUDAOpBuilder
class
FusedLambBuilder
(
CUDAOpBuilder
):
BUILD_VAR
=
'DS_BUILD_FUSED_LAMB'
NAME
=
"fused_lamb"
def
__init__
(
self
):
super
().
__init__
(
name
=
self
.
NAME
)
def
absolute_name
(
self
):
return
f
'deepspeed.ops.lamb.
{
self
.
NAME
}
_op'
def
sources
(
self
):
return
[
'csrc/lamb/fused_lamb_cuda.cpp'
,
'csrc/lamb/fused_lamb_cuda_kernel.cu'
]
def
include_paths
(
self
):
return
[
'csrc/includes'
]
def
cxx_args
(
self
):
return
[
'-O3'
]
+
self
.
version_dependent_macros
()
def
nvcc_args
(
self
):
return
[
'-lineinfo'
,
'-O3'
,
'--use_fast_math'
]
+
self
.
version_dependent_macros
()
+
self
.
compute_capability_args
()
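
Both the fused Adam and fused LAMB builders feed `version_dependent_macros()` into their compiler flags. As a worked example (standalone; the torch version is hard-coded purely for illustration), a hypothetical torch 1.8 install satisfies all three version gates:

# Equivalent standalone check of the version gating, with torch 1.8 assumed.
TORCH_MAJOR, TORCH_MINOR = 1, 8
macros = []
if (TORCH_MAJOR, TORCH_MINOR) > (1, 0):
    macros += ['-DVERSION_GE_1_1']
if (TORCH_MAJOR, TORCH_MINOR) > (1, 2):
    macros += ['-DVERSION_GE_1_3']
if (TORCH_MAJOR, TORCH_MINOR) > (1, 4):
    macros += ['-DVERSION_GE_1_5']
print(macros)  # ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']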
deepspeed/ops/op_builder/sparse_attn.py
0 → 100644
View file @
eadbbe09
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import
torch
import
warnings
from
.builder
import
OpBuilder
class
SparseAttnBuilder
(
OpBuilder
):
BUILD_VAR
=
"DS_BUILD_SPARSE_ATTN"
NAME
=
"sparse_attn"
def
__init__
(
self
):
super
().
__init__
(
name
=
self
.
NAME
)
def
absolute_name
(
self
):
return
f
'deepspeed.ops.sparse_attention.
{
self
.
NAME
}
_op'
def
sources
(
self
):
return
[
'csrc/sparse_attention/utils.cpp'
]
def
cxx_args
(
self
):
return
[
'-O2'
,
'-fopenmp'
]
def
is_compatible
(
self
):
# Check to see if llvm and cmake are installed since they are dependencies
required_commands
=
[
'llvm-config|llvm-config-9'
,
'cmake'
]
command_status
=
list
(
map
(
self
.
command_exists
,
required_commands
))
deps_compatible
=
all
(
command_status
)
# torch-cpu will not have a cuda version
if
torch
.
version
.
cuda
is
None
:
cuda_compatible
=
False
self
.
warning
(
f
"
{
self
.
NAME
}
cuda is not available from torch"
)
else
:
major
,
minor
=
torch
.
version
.
cuda
.
split
(
'.'
)[:
2
]
cuda_compatible
=
int
(
major
)
==
10
and
int
(
minor
)
>=
1
if
not
cuda_compatible
:
self
.
warning
(
f
"
{
self
.
NAME
}
requires CUDA version 10.1+, does not currently support >=11 or <10.1"
)
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
torch_compatible
=
TORCH_MAJOR
==
1
and
TORCH_MINOR
>=
5
if
not
torch_compatible
:
self
.
warning
(
f
'
{
self
.
NAME
}
requires a torch version >= 1.5 but detected
{
TORCH_MAJOR
}
.
{
TORCH_MINOR
}
'
)
return
super
().
is_compatible
(
)
and
deps_compatible
and
torch_compatible
and
cuda_compatible
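
Since `is_compatible()` only returns a boolean (and logs warnings), callers are expected to gate the build on it themselves. A minimal sketch, assuming `SparseAttnBuilder` is exported from `deepspeed.ops.op_builder`:

# Illustrative only: skip the sparse attention build when prerequisites are missing.
from deepspeed.ops.op_builder import SparseAttnBuilder  # assumed export

builder = SparseAttnBuilder()
if builder.is_compatible():
    sparse_attn_op = builder.load()
else:
    print(f"Skipping {builder.NAME}: needs llvm-config/cmake, torch >= 1.5 and CUDA 10.1.x")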