Unverified Commit 8823cc48 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
parents bce9499e 73f4dc57
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdio.h>

#include <array>
#include <stdexcept>

#include "cublas_wrappers.h"
// Strided batched GEMM helper adapted from LightSeq / DeepSpeed.
// Holds a GEMM problem description (Config) and issues
// cublas_strided_batched_gemm calls for the forward product and for the two
// operand gradients in the backward pass.
template <typename T>
class StridedBatchGemm {
public:
// Problem description for C = alpha * op_A(A) * op_B(B) + beta * C, batched.
struct Config {
int m;
int n;
int k;
float alpha;
float beta;
cublasOperation_t op_A;
cublasOperation_t op_B;
// Per-GEMM algorithm ids (forward, d_A, d_B).  99 presumably maps to
// CUBLAS_GEMM_DEFAULT_TENSOR_OP — confirm against cublas_wrappers.h.
std::array<int, 3> gemm_algos;
Config(float param_alpha, float param_beta, cublasOperation_t opA,
cublasOperation_t opB)
: alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
// Sets problem dimensions; m/n/k are uninitialized until this is called,
// so it must run before Forward/Backward.
void SetConfig(int mm, int nn, int kk) {
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config &config) : _config(config) {}
virtual ~StridedBatchGemm() {}
// Forward batched product: for each of the bsz batches,
// output = alpha * op_A(_buffer_a) * op_B(_buffer_b) + beta * output,
// with per-batch strides derived from the configured m/n/k.
void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
cublasHandle_t handle) {
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(
handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
_buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
}
// Backward pass: given d_output (gradient of the forward result), computes
// inpGradA and inpGradB, the gradients w.r.t. the two forward operands.
// Operand order and transpose flags are swapped depending on whether A was
// transposed in the forward pass.
void Backward(int bsz, const T *d_output, const T *_buffer_a,
const T *_buffer_b, cublasHandle_t handle,
T *inpGradA = nullptr, T *inpGradB = nullptr) {
// Effective row/col extents of d_A, depending on the forward op_A.
int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
cublasOperation_t op_b =
(_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
// Calculate d_A.
cublas_strided_batched_gemm(
handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
(_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
(_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
cublasGemmAlgo_t(_config.gemm_algos[1]));
// A need to transpose.
cublasOperation_t op_a =
(_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(
handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
_buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
}
// Updates the GEMM dimensions used by subsequent Forward/Backward calls.
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
};
#include <cooperative_groups.h>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float LN_EPSILON = 1e-8f;
#define TILE_DIM 32
// Clamps a value away from zero so later divisions by it stay finite:
// values with |x| <= LN_EPSILON are replaced by +/-LN_EPSILON (the sign of
// x is preserved, and exactly-zero x maps to +LN_EPSILON); larger values
// pass through unchanged.
template <typename T>
__forceinline__ __device__ T add_eps(T x) {
  if (fabsf(x) > LN_EPSILON) {
    return x;
  }
  return x < 0 ? -LN_EPSILON : LN_EPSILON;
}
/**
@brief: ker_layer_norm
Standard layer normalization forward pass.
Writes the normalized result and the per-token variance; also writes the
per-token mean when the means argument is non-null.
@thread
gridDim.x = batch_size * seq_len  (one block per token)
blockDim.x covers hidden_size (given here in float4 units, i.e. true
hidden dim / 4 — see launch_layer_norm)
@param
ln_res: [batch_size * seq_len, hidden_size], ln result.
vars: [batch_size * seq_len], variance per token
means: [batch_size * seq_len], mean per token, can be nullptr
inp: [batch_size * seq_len, hidden_size], ln input.
scale: [hidden_size], ln scale
bias: [hidden_size], ln bias
*/
template <typename T>
__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
                               const T *scale, const T *bias, int hidden_size) {
  // Step 0: each thread accumulates a partial sum and sum-of-squares over
  // its strided share of the row, reading four floats at a time.
  const float4 *row_in =
      reinterpret_cast<const float4 *>(inp) + blockIdx.x * hidden_size;
  float part_sum = 0.f;
  float part_sq_sum = 0.f;
  for (uint i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float4 v = row_in[i];
    part_sum += v.x + v.y + v.z + v.w;
    part_sq_sum += v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
  }
  // Step 1: reduce both statistics across the block in one call.
  float stats[2] = {part_sum, part_sq_sum};
  blockReduce<ReduceType::kSum, 2>(stats);
  float n_elems = float(hidden_size) * 4.f;  // scalar elements per row
  __shared__ float s_mean, s_var;
  if (threadIdx.x == 0) {
    s_mean = stats[0] / n_elems;
    if (means != nullptr) {
      means[blockIdx.x] = s_mean;
    }
    // Var = E[x^2] - E[x]^2, stabilized with LN_EPSILON.  The raw variance
    // is stored; s_var then holds its reciprocal square root for reuse.
    s_var = stats[1] / n_elems - s_mean * s_mean + LN_EPSILON;
    vars[blockIdx.x] = s_var;
    s_var = rsqrtf(s_var);
  }
  __syncthreads();
  // Step 2: normalize, scale, and shift each float4 chunk of the row.
  float4 *row_out =
      reinterpret_cast<float4 *>(ln_res) + blockIdx.x * hidden_size;
  for (uint i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float4 g = __ldg(reinterpret_cast<const float4 *>(scale) + i);
    float4 b = __ldg(reinterpret_cast<const float4 *>(bias) + i);
    float4 v = row_in[i];
    v.x = (v.x - s_mean) * s_var * g.x + b.x;
    v.y = (v.y - s_mean) * s_var * g.y + b.y;
    v.z = (v.z - s_mean) * s_var * g.z + b.z;
    v.w = (v.w - s_mean) * s_var * g.w + b.w;
    row_out[i] = v;
  }
}
// fp16 specialization: each float4 load carries 8 half values, processed as
// four __half2 pairs.  Statistics and normalization math are done in fp32
// for accuracy; hidden_size is given in float4 units (true hidden dim / 8).
template <>
__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
__half *means, const __half *inp,
const __half *scale, const __half *bias,
int hidden_size) {
// step 0. compute local sum
float l_sum = 0;
float l_square_sum = 0;
const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size;
for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float4 val_f4 = inp_f4[idx];
__half2 *val_h2 = (__half2 *)(&val_f4);
#pragma unroll
for (int i = 0; i < 4; i++) {
// Widen each half2 pair to float2 before accumulating.
float2 val_f2 = __half22float2(val_h2[i]);
l_sum += val_f2.x + val_f2.y;
l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
}
}
// step 1. compute reduce sum
float mean_dim = float(hidden_size) * 8.f;  // 8 halves per float4
float reduce_val[2] = {l_sum, l_square_sum};
blockReduce<ReduceType::kSum, 2>(reduce_val);
__shared__ float s_mean, s_var;
if (threadIdx.x == 0) {
s_mean = reduce_val[0] / mean_dim;
if (means != nullptr) {
means[blockIdx.x] = s_mean;
}
// Raw variance is stored (converted to __half on assignment); s_var then
// becomes rsqrt(var) for the normalization below.
s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
vars[blockIdx.x] = s_var;
s_var = rsqrtf(s_var);
}
__syncthreads();
// step 2. layer norm result
float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size;
for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
// load scale, bias, input
float4 scale_f4 = __ldg((const float4 *)scale + idx);
__half2 *scale_h2 = (__half2 *)(&scale_f4);
float4 bias_f4 = __ldg((const float4 *)bias + idx);
__half2 *bias_h2 = (__half2 *)(&bias_f4);
float4 val_f4 = inp_f4[idx];
__half2 *val_h2 = (__half2 *)(&val_f4);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 scale_f2 = __half22float2(scale_h2[i]);
float2 bias_f2 = __half22float2(bias_h2[i]);
float2 val_f2 = __half22float2(val_h2[i]);
// Normalize in fp32, then round back to half2 (round-to-nearest).
val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
val_h2[i] = __float22half2_rn(val_f2);
}
output_f4[idx] = val_f4;
}
}
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half
// *bias, int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x
// * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 2;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// }
// }
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half
// *bias, int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2];
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x +
// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x *
// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x
// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x +
// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x +
// val_f2_3.y * val_f2_3.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 4;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2);
// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2);
// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3);
// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2);
// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2);
// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3);
// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// float4 val_f4_2 = inp_f4[idx+2];
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 scale_f2_2 = __half22float2(scale_h2_2[i]);
// float2 scale_f2_3 = __half22float2(scale_h2_3[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 bias_f2_2 = __half22float2(bias_h2_2[i]);
// float2 bias_f2_3 = __half22float2(bias_h2_3[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var
// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) *
// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean)
// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] =
// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// output_f4[idx+2] = val_f4_2;
// output_f4[idx+3] = val_f4_3;
// }
// }
// Host launcher for the fp32 layer-norm forward kernel.
// One block per token; the kernel reads rows as float4, so hidden_dim must
// be a multiple of 4 (throws std::runtime_error otherwise).
template <>
void launch_layer_norm<float>(float *ln_res, float *vars, float *means,
                              const float *inp, const float *scale,
                              const float *bias, int batch_size, int hidden_dim,
                              cudaStream_t stream) {
  if (hidden_dim % 4 != 0) {
    throw std::runtime_error("violate hidden_dim % 4 = 0");
  }
  hidden_dim >>= 2;  // convert to float4 units for the kernel
  // Round the thread count up to a full warp, capped at MAX_THREADS.
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);  // one block per token/row
  dim3 block_dim(nthread);
  ker_layer_norm<float><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);
  // Surface invalid launch configurations immediately instead of letting
  // them fail silently at the next synchronizing call.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}
// Host launcher for the fp16 layer-norm forward kernel.
// One block per token; the kernel reads rows as float4 (8 halves at a time),
// so hidden_dim must be a multiple of 8 (throws std::runtime_error
// otherwise).
template <>
void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means,
                               const __half *inp, const __half *scale,
                               const __half *bias, int batch_size,
                               int hidden_dim, cudaStream_t stream) {
  if (hidden_dim % 8 != 0) {
    throw std::runtime_error("violate hidden_dim % 8 = 0");
  }
  hidden_dim >>= 3;  // convert to float4 units for the kernel
  // Round the thread count up to a full warp, capped at MAX_THREADS.
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);  // one block per token/row
  dim3 block_dim(nthread);
  ker_layer_norm<__half><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);
  // Surface invalid launch configurations immediately instead of letting
  // them fail silently at the next synchronizing call.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}
/**
@brief: ker_ln_bw_dgamma_dbetta
Layer norm backward kernel, computes the gradients of gamma and betta.
dbetta = sum(dout, dim=0)
dgamma = sum(xhat * dout, dim=0)
xhat = (input - mean) * rsqrt(var) or
(output - betta) / gamma
@thread
gridDim.x = hidden_size / 32
blockDim.x = 32
blockDim.y = 32
@param
gamma_grad: [hidden_size], gradient of gamma
betta_grad: [hidden_size], gradient of betta
out_grad: [batch_size * seq_len, hidden_size], gradient of ln output
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr
ln input if means is not nullptr
gamma: [hidden_size], gamma of ln,
used to compute xhat, maybe nullptr
betta: [hidden_size], betta of ln,
used to compute xhat, maybe nullptr
vars: [batch_size * seq_len], variance of ln forward,
used to compute xhat, maybe nullptr
means: [batch_size * seq_len], mean of ln forward,
used to compute xhat, maybe nullptr
(gamma && betta) ^ (vars && means) should be true
*/
template <typename T>
__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
const T *out_grad, const T *inp_or_out,
const T *gamma, const T *betta,
const T *vars, const T *means, int rows,
int width) {
// TILE_DIM x TILE_DIM scratch used to transpose the per-thread partial sums
// so that each warp-sized tile can finish one column's reduction with
// shuffles.
__shared__ float betta_buffer[TILE_DIM][TILE_DIM];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
// Each thread owns one column (idx) and strides down rows via threadIdx.y.
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
// Loop across inp height
float dbetta = 0;
float dgamma = 0;
float dout, val;
if (idx < width) {
if (means == nullptr) {
// Reconstruct xhat from the forward output: (out - betta) / gamma.
float vbetta = (float)betta[idx];
float vgamma = (float)gamma[idx];
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
dout = (float)out_grad[offset];
// inp_or_out is output
val = (float)inp_or_out[offset];
dbetta += dout;
dgamma += ((val - vbetta) / add_eps(vgamma) * dout);
offset += y_stride;
}
} else {
// Reconstruct xhat from the forward input: (inp - mean) * rsqrt(var).
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
dout = (float)out_grad[offset];
// inp_or_out is input
val = (float)inp_or_out[offset];
dbetta += dout;
dgamma += ((val - (float)means[r]) *
rsqrtf((float)vars[r] + LN_EPSILON) * dout);
offset += y_stride;
}
}
}
// Sum the shared buffer.
// Partials are written transposed; after the barrier each TILE_DIM-wide
// tile reduces one column of the original layout with shfl_down.
betta_buffer[threadIdx.x][threadIdx.y] = dbetta;
gamma_buffer[threadIdx.x][threadIdx.y] = dgamma;
__syncthreads();
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
__syncthreads();
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
// Lane 0 of each tile holds the complete column sum after the shuffles.
// NOTE(review): the write guard uses idx (this thread's own column) while
// pos indexes via threadIdx.y — this assumes width is a multiple of
// TILE_DIM; confirm against the launcher (gridDim.x = hidden_size / 32).
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (threadIdx.x == 0 && idx < width) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/**
@brief: ker_ln_bw_dinp
Layer norm backward kernel, computes the gradient of the ln input.
dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim)
* rsqrt(var)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
(output - betta) / gamma if mean is nullptr
dxhat = dout * gamma
@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size
@param
inp_grad: [batch_size * seq_len, hidden_size], gradient of ln input
out_grad: [batch_size * seq_len, hidden_size], gradient of ln output
residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input,
usually appear in pre-layer-norm for transformer layer, maybe nullptr
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr
ln input if means is not nullptr
gamma: [hidden_size], gamma of ln,
used to compute xhat and dxhat
betta: [hidden_size], betta of ln,
used to compute xhat, maybe nullptr
vars: [batch_size * seq_len], variance of ln forward,
used to compute xhat and dinp
means: [batch_size * seq_len], mean of ln forward,
used to compute xhat, maybe nullptr
*/
template <typename T>
__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad,
const T *residual_grad, const T *inp_or_out,
const T *gamma, const T *betta, const T *vars,
const T *means, int hidden_dim) {
// One block per token; hidden_dim is in float4 units.  Assumes a single
// pass (blockDim.x >= hidden_dim) — each thread handles one float4.
int offset = blockIdx.x * hidden_dim + threadIdx.x;
float4 dxhat, xhat;
float var_rsqrt;
if (threadIdx.x < hidden_dim) {
// step 0. dxhat = dout * gamma
dxhat = ((const float4 *)out_grad)[offset];
float4 vgamma = ((const float4 *)gamma)[threadIdx.x];
dxhat.x *= vgamma.x;
dxhat.y *= vgamma.y;
dxhat.z *= vgamma.z;
dxhat.w *= vgamma.w;
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
xhat = ((const float4 *)inp_or_out)[offset];
var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
if (means == nullptr) {
// inp_or_out is output, xhat = (output - betta) / gamma
float4 vbetta = ((const float4 *)betta)[threadIdx.x];
xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x);
xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y);
xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z);
xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w);
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float fmean = (float)means[blockIdx.x];
xhat.x = (xhat.x - fmean) * var_rsqrt;
xhat.y = (xhat.y - fmean) * var_rsqrt;
xhat.z = (xhat.z - fmean) * var_rsqrt;
xhat.w = (xhat.w - fmean) * var_rsqrt;
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
// All threads — including the inactive tail — must reach blockReduce, so
// tail lanes contribute zeros instead of returning early.
float reduce_val[2] = {0.f, 0.f};
if (threadIdx.x < hidden_dim) {
reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w;
reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z +
dxhat.w * xhat.w;
}
blockReduce<ReduceType::kSum, 2>(reduce_val);
__shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
if (threadIdx.x == 0) {
float mean_dim = hidden_dim * 4;
s_sum_dxhat = reduce_val[0] / mean_dim;
s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
}
__syncthreads();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
// Tail threads exit only after the barrier above, so the early return is
// safe with respect to __syncthreads().
if (threadIdx.x >= hidden_dim) {
return;
}
dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt;
dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt;
dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt;
dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt;
if (residual_grad) {
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
dxhat.x += dresidual.x;
dxhat.y += dresidual.y;
dxhat.z += dresidual.z;
dxhat.w += dresidual.w;
}
((float4 *)inp_grad)[offset] = dxhat;
}
// fp16 specialization of the layer-norm input-gradient kernel.  Each float4
// carries 8 half values processed as four __half2 pairs; hidden_dim is in
// float4 units (true hidden dim / 8).  The partial sums for blockReduce are
// accumulated in fp32 while dxhat/xhat are computed.
template <>
__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
const __half *residual_grad,
const __half *inp_or_out,
const __half *gamma, const __half *betta,
const __half *vars, const __half *means,
int hidden_dim) {
int offset = blockIdx.x * hidden_dim + threadIdx.x;
float2 dxhat[4], xhat[4];
float var_rsqrt;
float4 vtmp;
__half2 *tmp_h2;
float reduce_val[2] = {0.f, 0.f};
if (threadIdx.x < hidden_dim) {
// step 0. dxhat = dout * gamma
vtmp = ((const float4 *)out_grad)[offset];
tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x];
__half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vdout = __half22float2(tmp_h2[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
dxhat[i].x = vdout.x * vgamma.x;
dxhat[i].y = vdout.y * vgamma.y;
reduce_val[0] += dxhat[i].x + dxhat[i].y;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
// vtmp/tmp_h2 are reused: they now hold inp_or_out and later receive the
// final gradient before the store at the end of the kernel.
vtmp = ((const float4 *)inp_or_out)[offset];
var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
if (means == nullptr) {
// inp_or_out is output, xhat = (output - betta) / gamma
float4 vbetta = ((const float4 *)betta)[threadIdx.x];
__half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vout = __half22float2(tmp_h2[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vbetta = __half22float2(betta_h2[i]);
xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
}
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float fmean = (float)means[blockIdx.x];
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vinp = __half22float2(tmp_h2[i]);
xhat[i].x = (vinp.x - fmean) * var_rsqrt;
xhat[i].y = (vinp.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
// All threads must reach this call; inactive tail lanes contribute zeros.
blockReduce<ReduceType::kSum, 2>(reduce_val);
__shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
if (threadIdx.x == 0) {
float mean_dim = hidden_dim * 8;
s_sum_dxhat = reduce_val[0] / mean_dim;
s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
}
__syncthreads();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
// Tail threads exit only after the barrier above.
if (threadIdx.x >= hidden_dim) {
return;
}
if (residual_grad) {
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
__half *hdres = reinterpret_cast<__half *>(&dresidual);
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i]));
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i + 1]));
}
} else {
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
}
}
((float4 *)inp_grad)[offset] = vtmp;
}
// x2-unrolled fp16 variant of ker_ln_bw_dinp: each thread processes two
// consecutive float4 chunks (16 half values) of its token's row, halving the
// required blockDim.x for wide hidden sizes.  hidden_dim here is the true
// hidden dim / 16; assumes a single pass (blockDim.x >= hidden_dim).
__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
const __half *residual_grad,
const __half *inp_or_out, const __half *gamma,
const __half *betta, const __half *vars,
const __half *means, int hidden_dim) {
int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2;
float2 dxhat[4], xhat[4];
float2 dxhat_1[4], xhat_1[4];
float var_rsqrt;
float4 vtmp, vtmp_1;
__half2 *tmp_h2;
__half2 *tmp_h2_1;
float reduce_val[2] = {0.f, 0.f};
if (threadIdx.x < hidden_dim) {
// step 0. dxhat = dout * gamma
vtmp = ((const float4 *)out_grad)[offset];
vtmp_1 = ((const float4 *)out_grad)[offset + 1];
tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1);
float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2];
float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1];
__half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
__half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vdout = __half22float2(tmp_h2[i]);
float2 vdout_1 = __half22float2(tmp_h2_1[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
dxhat[i].x = vdout.x * vgamma.x;
dxhat[i].y = vdout.y * vgamma.y;
dxhat_1[i].x = vdout_1.x * vgamma_1.x;
dxhat_1[i].y = vdout_1.y * vgamma_1.y;
reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
// vtmp/vtmp_1 are reused: they now hold inp_or_out and later receive the
// final gradients before the stores at the end of the kernel.
vtmp = ((const float4 *)inp_or_out)[offset];
vtmp_1 = ((const float4 *)inp_or_out)[offset + 1];
var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
if (means == nullptr) {
// inp_or_out is output, xhat = (output - betta) / gamma
float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x];
float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1];
__half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
__half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vout = __half22float2(tmp_h2[i]);
float2 vout_1 = __half22float2(tmp_h2_1[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
float2 vbetta = __half22float2(betta_h2[i]);
float2 vbetta_1 = __half22float2(betta_h2_1[i]);
xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x);
xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
}
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float fmean = (float)means[blockIdx.x];
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vinp = __half22float2(tmp_h2[i]);
float2 vinp_1 = __half22float2(tmp_h2_1[i]);
xhat[i].x = (vinp.x - fmean) * var_rsqrt;
xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt;
xhat[i].y = (vinp.y - fmean) * var_rsqrt;
xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
// All threads must reach this call; inactive tail lanes contribute zeros.
blockReduce<ReduceType::kSum, 2>(reduce_val);
__shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
if (threadIdx.x == 0) {
float mean_dim = hidden_dim * 8 * 2;
s_sum_dxhat = reduce_val[0] / mean_dim;
s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
}
__syncthreads();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
// Tail threads exit only after the barrier above.
if (threadIdx.x >= hidden_dim) {
return;
}
if (residual_grad) {
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
__half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i]));
tmp_h2_1[i].x = __float2half(
(dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i]));
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i + 1]));
tmp_h2_1[i].y = __float2half(
(dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i + 1]));
}
} else {
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_1[i].x = __float2half(
(dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_1[i].y = __float2half(
(dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
}
}
((float4 *)inp_grad)[offset] = vtmp;
((float4 *)inp_grad)[offset + 1] = vtmp_1;
}
/**
Layer-norm backward w.r.t. the input, x4 variant: each thread owns four
consecutive float4 loads (32 __half values) of the hidden dimension, so
`hidden_dim` here is the original hidden size divided by 32.

dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim)
       * rsqrt(var) [+ residual_grad]

BUGFIX: the residual-grad path previously added `hdres_1[2*i+1]` into the
odd lanes of chunks 2 and 3 (copy-paste from the x2 kernel); it must use
`hdres_2` / `hdres_3` respectively.
*/
__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
                                  const __half *residual_grad,
                                  const __half *inp_or_out, const __half *gamma,
                                  const __half *betta, const __half *vars,
                                  const __half *means, int hidden_dim) {
  int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4;
  float2 dxhat[4], xhat[4];
  float2 dxhat_1[4], xhat_1[4];
  float2 dxhat_2[4], xhat_2[4];
  float2 dxhat_3[4], xhat_3[4];
  float var_rsqrt;
  float4 vtmp, vtmp_1, vtmp_2, vtmp_3;
  __half2 *tmp_h2;
  __half2 *tmp_h2_1;
  __half2 *tmp_h2_2;
  __half2 *tmp_h2_3;
  float reduce_val[2] = {0.f, 0.f};
  if (threadIdx.x < hidden_dim) {
    // step 0. dxhat = dout * gamma
    vtmp = ((const float4 *)out_grad)[offset];
    vtmp_1 = ((const float4 *)out_grad)[offset + 1];
    vtmp_2 = ((const float4 *)out_grad)[offset + 2];
    vtmp_3 = ((const float4 *)out_grad)[offset + 3];
    tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
    tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1);
    tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2);
    tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3);
    float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4];
    float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1];
    float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2];
    float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3];
    __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
    __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1);
    __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2);
    __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 vdout = __half22float2(tmp_h2[i]);
      float2 vdout_1 = __half22float2(tmp_h2_1[i]);
      float2 vdout_2 = __half22float2(tmp_h2_2[i]);
      float2 vdout_3 = __half22float2(tmp_h2_3[i]);
      float2 vgamma = __half22float2(gamma_h2[i]);
      float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
      float2 vgamma_2 = __half22float2(gamma_h2_2[i]);
      float2 vgamma_3 = __half22float2(gamma_h2_3[i]);
      dxhat[i].x = vdout.x * vgamma.x;
      dxhat[i].y = vdout.y * vgamma.y;
      dxhat_1[i].x = vdout_1.x * vgamma_1.x;
      dxhat_1[i].y = vdout_1.y * vgamma_1.y;
      dxhat_2[i].x = vdout_2.x * vgamma_2.x;
      dxhat_2[i].y = vdout_2.y * vgamma_2.y;
      dxhat_3[i].x = vdout_3.x * vgamma_3.x;
      dxhat_3[i].y = vdout_3.y * vgamma_3.y;
      reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y +
                       dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x +
                       dxhat_3[i].y;
    }
    /*
    step 1. xhat = (output - betta) / gamma or
    (input - mean) * rsqrtf(var)
    */
    vtmp = ((const float4 *)inp_or_out)[offset];
    vtmp_1 = ((const float4 *)inp_or_out)[offset + 1];
    vtmp_2 = ((const float4 *)inp_or_out)[offset + 2];
    vtmp_3 = ((const float4 *)inp_or_out)[offset + 3];
    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means == nullptr) {
      // inp_or_out is output, xhat = (output - betta) / gamma
      float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x];
      float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1];
      float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2];
      float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3];
      __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
      __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1);
      __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2);
      __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3);
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vout = __half22float2(tmp_h2[i]);
        float2 vout_1 = __half22float2(tmp_h2_1[i]);
        float2 vout_2 = __half22float2(tmp_h2_2[i]);
        float2 vout_3 = __half22float2(tmp_h2_3[i]);
        float2 vgamma = __half22float2(gamma_h2[i]);
        float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
        float2 vgamma_2 = __half22float2(gamma_h2_2[i]);
        float2 vgamma_3 = __half22float2(gamma_h2_3[i]);
        float2 vbetta = __half22float2(betta_h2[i]);
        float2 vbetta_1 = __half22float2(betta_h2_1[i]);
        float2 vbetta_2 = __half22float2(betta_h2_2[i]);
        float2 vbetta_3 = __half22float2(betta_h2_3[i]);
        xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
        xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x);
        xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x);
        xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x);
        xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
        xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
        xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y);
        xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y);
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
        reduce_val[1] +=
            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
        reduce_val[1] +=
            xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
        reduce_val[1] +=
            xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
      }
    } else {
      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
      float fmean = (float)means[blockIdx.x];
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vinp = __half22float2(tmp_h2[i]);
        float2 vinp_1 = __half22float2(tmp_h2_1[i]);
        float2 vinp_2 = __half22float2(tmp_h2_2[i]);
        float2 vinp_3 = __half22float2(tmp_h2_3[i]);
        xhat[i].x = (vinp.x - fmean) * var_rsqrt;
        xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt;
        xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt;
        xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt;
        xhat[i].y = (vinp.y - fmean) * var_rsqrt;
        xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
        xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt;
        xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt;
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
        reduce_val[1] +=
            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
        reduce_val[1] +=
            xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
        reduce_val[1] +=
            xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
      }
    }
  }
  /* step2. block reduce sum for dxhat and dxhat*xhat */
  // Entered by ALL threads of the block (including threadIdx.x >= hidden_dim,
  // which contribute zeros) so the block reduction and barrier are safe.
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
  if (threadIdx.x == 0) {
    // mean_dim = hidden_dim * 8 * 4: each "hidden_dim" unit here covers
    // 4 float4 loads of 8 halves each.
    float mean_dim = hidden_dim * 8 * 4;
    s_sum_dxhat = reduce_val[0] / mean_dim;
    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
  }
  __syncthreads();
  /*
  step3. compute input gradient
  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
  */
  if (threadIdx.x >= hidden_dim) {
    return;
  }
  if (residual_grad) {
    // Add the residual grad,
    // usually in pre-layer-norm for transformer layer
    float4 dresidual = ((const float4 *)residual_grad)[offset];
    float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
    float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2];
    float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3];
    __half *hdres = reinterpret_cast<__half *>(&dresidual);
    __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
    __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2);
    __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i]));
      tmp_h2_1[i].x = __float2half(
          (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_1[2 * i]));
      tmp_h2_2[i].x = __float2half(
          (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_2[2 * i]));
      tmp_h2_3[i].x = __float2half(
          (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_3[2 * i]));
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i + 1]));
      tmp_h2_1[i].y = __float2half(
          (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_1[2 * i + 1]));
      tmp_h2_2[i].y = __float2half(
          (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_2[2 * i + 1]));  // fixed: was hdres_1
      tmp_h2_3[i].y = __float2half(
          (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_3[2 * i + 1]));  // fixed: was hdres_1
    }
  } else {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_1[i].x = __float2half(
          (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_2[i].x = __float2half(
          (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_3[i].x = __float2half(
          (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_1[i].y = __float2half(
          (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_2[i].y = __float2half(
          (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_3[i].y = __float2half(
          (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
    }
  }
  ((float4 *)inp_grad)[offset] = vtmp;
  ((float4 *)inp_grad)[offset + 1] = vtmp_1;
  ((float4 *)inp_grad)[offset + 2] = vtmp_2;
  ((float4 *)inp_grad)[offset + 3] = vtmp_3;
}
/**
Layer norm backward,
compute the gradient of gamma, betta and input.
dbetta = sum(dout, dim=0)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
(output - betta) / gamma if mean is nullptr
dgamma = sum(xhat * dout, dim=0)
dxhat = dout * gamma
dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim)
* rsqrt(var)
residual_grad, means, betta can be nullptr.
residual_grad will be added to dinp if it is not nullptr
which is useful in transformer layer when pre-ln
means and betta are only used to compute xhat,
(means == nullptr) ^ (betta == nullptr) should be true
*/
template <>
void launch_ln_bw<float>(float *gamma_grad, float *betta_grad, float *inp_grad,
                         const float *out_grad, const float *residual_grad,
                         const float *inp_or_out, const float *gamma,
                         const float *betta, const float *vars,
                         const float *means, int batch, int hidden_dim,
                         cudaStream_t stream[2]) {
  // Gradient of gamma and betta, launched on stream[0].
  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
  dim3 block_dim(TILE_DIM, TILE_DIM);
  ker_ln_bw_dgamma_dbetta<float><<<grid_dim, block_dim, 0, stream[0]>>>(
      gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
      batch, hidden_dim);
  // Gradient of the input, launched on stream[1]. The float kernel loads
  // float4 vectors, so hidden_dim must be a multiple of 4, and one block
  // must cover the whole (vectorized) row.
  if (hidden_dim % 4 != 0 || hidden_dim > 4096) {
    throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096");
  }
  const int vec_dim = hidden_dim / 4;  // row length in float4 units
  const int nthread = min(((vec_dim + 31) / 32) * 32, MAX_THREADS);
  ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
      vec_dim);
}
/**
Half-precision specialization of the layer-norm backward launcher.

Gamma/betta gradients run on stream[0]; the input gradient runs on
stream[1], dispatched to the x1/x2/x4 kernels by row width. Each kernel
loads float4 vectors (8 halves), so hidden_dim must be a multiple of 8,
and rows wider than 32768 halves are unsupported.
*/
template <>
void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
                          __half *inp_grad, const __half *out_grad,
                          const __half *residual_grad, const __half *inp_or_out,
                          const __half *gamma, const __half *betta,
                          const __half *vars, const __half *means, int batch,
                          int hidden_dim, cudaStream_t stream[2]) {
  // compute grad of gamma and betta
  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
  dim3 block_dim(TILE_DIM, TILE_DIM);
  ker_ln_bw_dgamma_dbetta<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
      gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
      batch, hidden_dim);
  // compute grad of input
  if (hidden_dim % 8 != 0) {
    throw std::runtime_error("hidden_dim % 8 != 0");
  }
  hidden_dim >>= 3;  // row length in float4 units (8 halves each)
  if (hidden_dim * 8 <= 8192) {
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) {
    hidden_dim >>= 1;  // x2 kernel: each thread covers 2 float4 loads
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) {
    hidden_dim >>= 2;  // x4 kernel: each thread covers 4 float4 loads
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else {
    // Fixed error message: the divisibility case already threw above, so
    // reaching here can only mean the row is too wide.
    throw std::runtime_error("hidden_dim > 32768");
  }
}
#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>
#include <cub/cub.cuh>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;
/**
@brief: softmax_kernel
Softmax forward kernel for
enc-self-attn, dec-self-attn, encdec-attn
@thread
gridDim.x = dynamic
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = from_len
@param
inp: [batch_size, nhead, from_len, to_len], softmax input.
attn_mask: [batch_size, to_len], padding tokens are -inf,
non padding tokens are 0.
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
                                 int to_len, bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  // Each grid.x block iterates over a strided subset of the from_len rows,
  // reducing one row per pass.
  const int token_per_reduce = 1;
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;
  // Per-thread slice of the padding-mask row; out-of-range elements are
  // filled with -inf so they vanish under max/exp.
  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }
  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
                              REDUCE_FLOAT_INF_NEG);
    }
    /* step 1. compute max */
    // thread local max
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          // Padding mask: add -inf for padded positions.
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          // Causal mask: positions beyond the current token get -inf.
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // block reduce max
    blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
    // write shared
    __shared__ float s_max[token_per_reduce];
    if (threadIdx.x == 0) {
      for (int i = 0; i < token_per_reduce; i++) {
        s_max[i] = l_max[i];
      }
    }
    __syncthreads();
    /* step 2. compute sum */
    // thread local sum
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        // Subtract the row max before exp for numerical stability.
        val[i][j] = __expf(val[i][j] - s_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // block reduce sum
    blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
    // write shared
    __shared__ float s_sum[token_per_reduce];
    if (threadIdx.x == 0) {
      for (int i = 0; i < token_per_reduce; i++) {
        // Store the reciprocal once; EPSILON guards an all-masked row.
        s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      }
    }
    __syncthreads();
    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
      }
      // Softmax is computed in place over the input rows.
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}
// Variant of ker_attn_softmax for to_len <= 64 where the whole block is a
// single warp (block_dim <= 32): warp reductions replace block reductions,
// so no shared-memory staging or __syncthreads barriers are needed.
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
                                      int to_len, bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  const int token_per_reduce = 1;
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;
  // Per-thread slice of the padding-mask row, padded with -inf.
  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }
  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
                              REDUCE_FLOAT_INF_NEG);
    }
    /* step 1. compute max */
    // thread local max
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          // Padding mask: add -inf for padded positions.
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          // Causal mask: positions beyond the current token get -inf.
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // warp reduce max: result is broadcast to every lane, no shared needed
    warpReduce<ReduceType::kMax, token_per_reduce>(l_max);
    /* step 2. compute sum */
    // thread local sum
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        // Subtract the row max before exp for numerical stability.
        val[i][j] = __expf(val[i][j] - l_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // warp reduce sum
    warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);
    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      // EPSILON guards division when an entire row is masked out.
      l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
      }
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}
/*
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
                                int batch_size, int nhead, int from_len,
                                int to_len, bool mask_future,
                                cudaStream_t stream) {
  // grid.y = batch, grid.z = head; grid.x tiles the from_len rows for the
  // wider kernels. Template args pick (block_dim, ele_per_thread) so that
  // block_dim * ele_per_thread >= to_len.
  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 64) {
    ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 128) {
    grid_dim.x = 16;
    ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 256) {
    grid_dim.x = 32;
    ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 512) {
    grid_dim.x = 64;
    ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  throw std::runtime_error(
      "Sequence length greater than 512 is currently not supported");
}
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
                                 int batch_size, int nhead, int from_len,
                                 int to_len, bool mask_future,
                                 cudaStream_t stream) {
  // Same dispatch as the float path but with half the grid.x tiling, since
  // the half kernels move twice as much data per vectorized access.
  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 64) {
    ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 128) {
    grid_dim.x = 8;
    ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 256) {
    grid_dim.x = 16;
    ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  if (to_len <= 512) {
    grid_dim.x = 32;
    ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
    return;
  }
  throw std::runtime_error(
      "Sequence length greater than 512 is currently not supported");
}
/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention.
@thread
gridDim.x = batch_size * nhead * seq_len / warps_per_block
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
@param
grad: [batch_size, nhead, seq_len, seq_len], output grad.
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
*/
// Softmax backward: for each row, grad_out = softmax * (grad_in - sum(grad_in
// * softmax)). One warp (threadIdx.x in [0, WARP_SIZE)) handles one row of
// length softmax_length, strided by WARP_SIZE; blockDim.y warps per block.
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
  int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int offset = batch_idx * softmax_length + threadIdx.x;
  grad += offset;
  inp += offset;
  // Registers beyond softmax_length stay uninitialized; they are never read
  // because the same curr_idx guard protects both loops.
  T grad_reg[ITERATIONS];
  T inp_reg[ITERATIONS];
  float sum = 0.0;
#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (curr_idx < softmax_length) {
      grad_reg[i] = grad[i * WARP_SIZE];
      inp_reg[i] = inp[i * WARP_SIZE];
      sum += (float)grad_reg[i] * (float)inp_reg[i];
    }
  }
  // Butterfly (xor-shuffle) reduction: every lane ends with the full row sum.
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (curr_idx < softmax_length)
      grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
  }
}
// Launcher for softmax backward: one warp per row, 4 warps per block, with
// the per-thread iteration count chosen so WARP_SIZE * ITERATIONS >=
// softmax_len.
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
                            int softmax_len, cudaStream_t stream) {
  const int warps_per_block = 4;
  // rows = batch_size * nhead * from_len
  // NOTE(review): grid size rows / warps_per_block truncates, so the last
  // rows % 4 rows would be skipped if rows were not a multiple of 4 —
  // callers appear to guarantee divisibility; confirm.
  dim3 grid_dim(rows / warps_per_block);
  dim3 block_dim(WARP_SIZE, warps_per_block);
  if (softmax_len <= 32)
    ker_attn_softmax_bw<T, 1>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 64)
    ker_attn_softmax_bw<T, 2>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 128)
    ker_attn_softmax_bw<T, 4>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 256)
    ker_attn_softmax_bw<T, 8>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 384)
    ker_attn_softmax_bw<T, 12>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 512)
    ker_attn_softmax_bw<T, 16>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 768)
    ker_attn_softmax_bw<T, 24>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 1024)
    ker_attn_softmax_bw<T, 32>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 2048)
    ker_attn_softmax_bw<T, 64>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else
    throw std::runtime_error(
        std::string(
            "Special sequence length found in softmax backward, seq_len: ") +
        std::to_string(softmax_len));
}
// Explicit instantiations for the two supported element types.
template void launch_attn_softmax_bw<__half>(__half *out_grad,
                                             const __half *soft_inp, int rows,
                                             int softmax_len,
                                             cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
                                            const float *soft_inp, int rows,
                                            int softmax_len,
                                            cudaStream_t stream);
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape input
during backward progress of encoder self-attention
@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
input: [batch_size, seq_len, hidden_dim]
output: [batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
                               int head_dim);

template <>
__global__ void transform_0213<float>(float *output, const float *input,
                                      int hidden_dim, int head_dim) {
  // One block per (batch, token); threads stride over the float4-packed
  // hidden dimension (hidden_dim and head_dim are in float4 units here).
  const int batch_id = blockIdx.x;
  const int token_id = blockIdx.y;
  const int seq_len = gridDim.y;
  const int nhead = hidden_dim / head_dim;
  // Source row base in the [b, s, h] layout.
  const float4 *src = reinterpret_cast<const float4 *>(input) +
                      flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // Destination base in the [b, nh, s, ad] layout.
  float4 *dst = reinterpret_cast<float4 *>(output) +
                flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
  for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    const int head_id = i / head_dim;
    const int dim_id = i - head_id * head_dim;
    dst[flat_3dim(head_id, 0, dim_id, seq_len, head_dim)] = src[i];
  }
}
template <>
__global__ void transform_0213<__half>(__half *output, const __half *input,
                                       int hidden_dim, int head_dim) {
  // Identical to the float specialization: both operate on float4 units
  // (8 halves per vector), so the dims arrive pre-divided by the launcher.
  const int batch_id = blockIdx.x;
  const int token_id = blockIdx.y;
  const int seq_len = gridDim.y;
  const int nhead = hidden_dim / head_dim;
  // Source row base in the [b, s, h] layout.
  const float4 *src = reinterpret_cast<const float4 *>(input) +
                      flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // Destination base in the [b, nh, s, ad] layout.
  float4 *dst = reinterpret_cast<float4 *>(output) +
                flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
  for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    const int head_id = i / head_dim;
    const int dim_id = i - head_id * head_dim;
    dst[flat_3dim(head_id, 0, dim_id, seq_len, head_dim)] = src[i];
  }
}
// [b, s, h] -> [b, nh, s, ad]
template <>
void launch_transform_0213<float>(float *output, const float *input,
                                  int batch_size, int seq_len, int hidden_dim,
                                  int nhead, cudaStream_t stream) {
  // Convert to float4 units: the kernel moves 4 floats per element.
  const int vec_hidden = hidden_dim / 4;
  const int head_dim = vec_hidden / nhead;
  dim3 grid_dim(batch_size, seq_len);
  dim3 block_dim(min(vec_hidden, MAX_THREADS));
  transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
      output, input, vec_hidden, head_dim);
}
template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
                                   int batch_size, int seq_len, int hidden_dim,
                                   int nhead, cudaStream_t stream) {
  // Convert to float4 units: the kernel moves 8 halves per element.
  const int vec_hidden = hidden_dim / 8;
  const int head_dim = vec_hidden / nhead;
  dim3 grid_dim(batch_size, seq_len);
  dim3 block_dim(min(vec_hidden, MAX_THREADS));
  transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
      output, input, vec_hidden, head_dim);
}
/**
@brief: bias_add_transform_20314
Add bias to input, transform from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
                                         const T *bias, int dim_3, int dim_4);

template <>
__global__ void bias_add_transform_20314<float>(float *output,
                                                const float *input,
                                                const float *bias, int dim_3,
                                                int dim_4) {
  // Block (x, y, z) = (id0, id1, id2); threads stride over dim_3 * dim_4
  // in float4 units, adding bias and scattering into the permuted layout
  // [dim_2, dim_0, dim_3, dim_1, dim_4].
  const int id0 = blockIdx.x;
  const int id1 = blockIdx.y;
  const int id2 = blockIdx.z;
  const int dim_0 = gridDim.x;
  const int dim_1 = gridDim.y;
  const int dim_2 = gridDim.z;
  const int dim_34 = dim_3 * dim_4;
  const float4 *qkv4 = reinterpret_cast<const float4 *>(input) +
                       flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
  const float4 *bias4 =
      reinterpret_cast<const float4 *>(bias) + flat_2dim(id2, 0, dim_34);
  float4 *res4 = reinterpret_cast<float4 *>(output) +
                 flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
  for (int i = threadIdx.x; i < dim_34; i += blockDim.x) {
    float4 v = qkv4[i];
    const float4 b = bias4[i];
    v.x += b.x;
    v.y += b.y;
    v.z += b.z;
    v.w += b.w;
    const int id3 = i / dim_4;
    const int id4 = i % dim_4;
    res4[flat_3dim(id3, 0, id4, dim_1, dim_4)] = v;
  }
}
template <>
__global__ void bias_add_transform_20314<__half>(__half *output,
                                                 const __half *input,
                                                 const __half *bias, int dim_3,
                                                 int dim_4) {
  // Same scatter as the float specialization, but each float4 carries
  // 8 halves, added pairwise via __hadd2.
  const int id0 = blockIdx.x;
  const int id1 = blockIdx.y;
  const int id2 = blockIdx.z;
  const int dim_0 = gridDim.x;
  const int dim_1 = gridDim.y;
  const int dim_2 = gridDim.z;
  const int dim_34 = dim_3 * dim_4;
  const float4 *qkv4 = reinterpret_cast<const float4 *>(input) +
                       flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
  const float4 *bias4 =
      reinterpret_cast<const float4 *>(bias) + flat_2dim(id2, 0, dim_34);
  float4 *res4 = reinterpret_cast<float4 *>(output) +
                 flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
  float4 vqkv4;
  float4 vbias4;
  float4 vres4;
  __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
  __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
  __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);
  for (int i = threadIdx.x; i < dim_34; i += blockDim.x) {
    vqkv4 = qkv4[i];
    vbias4 = bias4[i];
#pragma unroll
    for (int j = 0; j < 4; ++j) {
      h2_res[j] = __hadd2(h2_qkv[j], h2_bias[j]);
    }
    const int id3 = i / dim_4;
    const int id4 = i % dim_4;
    res4[flat_3dim(id3, 0, id4, dim_1, dim_4)] = vres4;
  }
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
                                            const float *bias, int dim_0,
                                            int dim_1, int dim_2, int dim_3,
                                            int dim_4, cudaStream_t stream) {
  // The innermost dim is processed in float4 units (4 floats each).
  const int vec_dim_4 = dim_4 / 4;
  dim3 grid_dim(dim_0, dim_1, dim_2);
  dim3 block_dim(min(dim_3 * vec_dim_4, MAX_THREADS));
  bias_add_transform_20314<float><<<grid_dim, block_dim, 0, stream>>>(
      output, input, bias, dim_3, vec_dim_4);
}
template <>
void launch_bias_add_transform_20314<__half>(__half *output,
                                             const __half *input,
                                             const __half *bias, int dim_0,
                                             int dim_1, int dim_2, int dim_3,
                                             int dim_4, cudaStream_t stream) {
  // The innermost dim is processed in float4 units (8 halves each).
  const int vec_dim_4 = dim_4 / 8;
  dim3 grid_dim(dim_0, dim_1, dim_2);
  dim3 block_dim(min(dim_3 * vec_dim_4, MAX_THREADS));
  bias_add_transform_20314<__half><<<grid_dim, block_dim, 0, stream>>>(
      output, input, bias, dim_3, vec_dim_4);
}
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
input: [trans_count, batch_size, nhead, seq_len, head_dim]
output: [batch_size, seq_len, trans_count, nhead, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
trans_count: 1 or 3, the number of matrices to be transformed
*/
// Merge attention heads: copy [tc, b, nh, s, ad] -> [b, s, tc, nh, ad].
// One thread moves one float4 element; dims here are in float4 units
// (the launcher pre-divides head_dim by 4 or 8).
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
                                 int seq_len, int trans_count, int nhead,
                                 int head_dim, int num_all) {
  int offset = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid tail guard: the last block may extend past num_all elements.
  if (offset >= num_all) {
    return;
  }
  int trans_id, batch_id, head_id, token_id, dim_id;
  // Decompose the flat source offset into the 5 source coordinates.
  decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
                 &batch_id, &head_id, &token_id, &dim_id);
  // [b, s, tc, nh, ad]
  int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
                             seq_len, trans_count, nhead, head_dim);
  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  res4[trg_offset] = input4[offset];
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
                                    int batch_size, int seq_len, int hidden_dim,
                                    int nhead, int trans_count,
                                    cudaStream_t stream) {
  // Work in float4 units (4 floats each); one thread per element.
  const int vec_hidden = hidden_dim / 4;
  const int head_dim = vec_hidden / nhead;
  const int num_all = batch_size * seq_len * trans_count * vec_hidden;
  const int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;  // ceil-div
  transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, head_dim,
      num_all);
}
template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
                                     int batch_size, int seq_len,
                                     int hidden_dim, int nhead, int trans_count,
                                     cudaStream_t stream) {
  // Work in float4 units (8 halves each); one thread per element.
  const int vec_hidden = hidden_dim / 8;
  const int head_dim = vec_hidden / nhead;
  const int num_all = batch_size * seq_len * trans_count * vec_hidden;
  const int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;  // ceil-div
  transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, head_dim,
      num_all);
}
#include "multihead_attention_1d.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp>
#endif
#include <iostream>
#include "context.h"
#include "kernels.h"
// Constructs one fused multi-head attention layer.
// hidden_size must be divisible by num_heads (asserted below). Device
// buffers are NOT allocated here; they are created when SetPG() runs
// allocate_mem_buffer().
template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
                                          int max_seq_len, int hidden_size,
                                          int num_heads,
                                          float attn_prob_dropout_ratio,
                                          float hidden_output_dropout_ratio,
                                          bool pre_or_postLayerNorm)
    : _layer_id(layer_id),
      _max_batch_tokens(max_batch_tokens),
      _max_seq_len(max_seq_len),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      // Fused QKV projection: one GEMM producing 3 * hidden_size outputs.
      _qkv_linear(
          typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
      _attn_out_linear(
          typename FeedForward<T>::Config(hidden_size, hidden_size)),
      _attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
               _max_batch_tokens),
      _softmax(typename Softmax<T>::Config(num_heads)),
      // Sized for the largest possible score matrix / hidden activation.
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
                         _max_batch_tokens * _heads * _max_seq_len),
      _attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
                    _max_batch_tokens * _hidden_size),
      // Q*K^T batched GEMM, pre-scaled by 1/sqrt(head_dim).
      _attn_scores(typename StridedBatchGemm<T>::Config(
          (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
          CUBLAS_OP_N)),
      // softmax(Q*K^T) * V batched GEMM.
      _attn_context(typename StridedBatchGemm<T>::Config(
          T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
  assert(_hidden_size % _heads == 0);
}
// Releases the per-layer device buffers and the shared scratch buffer
// (see free_mem_buffer()).
template <typename T>
MultiHeadAttention<T>::~MultiHeadAttention() {
  free_mem_buffer();
}
// Forward pass of the fused attention layer.
// input_ptr / output_ptr: [batch_size, seq_len, hidden_size] activations.
// buffer: device scratch reused between stages (QKV projection output,
// then attention context).
// With a process group set, heads are sharded by pg_size and the output
// projection's partial result is all-reduced across ranks.
template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          T *output_ptr, T *buffer) {
  // Q/K/V live back to back in _qkv_ptr, each _batch_dim / pg_size long.
  T *q_tf_ptr = _qkv_ptr;
  T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
  T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
  // Pre-LN: normalize the input before the QKV projection.
  if (_pre_or_postLayerNorm) {
    _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
                      _cublasHandle);
  // Add the QKV bias and reshape [b, s, 3, nh, ad] -> [3, b, nh, s, ad].
  launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
                                     _batch_size, _seq_len, 3, _heads / pg_size,
                                     _hidden_size / _heads, _stream);
  // attention scores, q*k
  _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
                       _cublasHandle);
  // Softmax + Mask
  _softmax.reset_size(_heads / pg_size);
  _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
                   _seq_len, _stream, true);
  // attn prob dropout.
  _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
                             _batch_heads * _seq_len * _seq_len, _stream);
  // attention context, score * v
  _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
                        _cublasHandle);
  // [b, nh, s, ad] -> [b, s, nh, ad]
  launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
                             _hidden_size / pg_size, _heads / pg_size, 1,
                             _stream);
  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
                           output_ptr, _cublasHandle);
  // allreduce partial outputs across the tensor-parallel group (no-op for
  // a single rank).
  // NOTE(review): comparing pg against a default c10::detail::UniqueVoidPtr
  // is an unusual null-check idiom for an intrusive_ptr -- confirm it
  // matches the check in SetPG().
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
  } else {
    auto data_type = torch::kFloat;
    if (typeid(T) != typeid(float)) {
      data_type = torch::kHalf;
    }
    // from_blob borrows output_ptr; no copy is made.
    auto output_tensor = torch::from_blob(
        output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
        torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
    work->wait();
  }
  // Output bias + dropout + residual add with the layer input.
  _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
                                      _attn_ob_ptr, _batch_tokens, _hidden_size,
                                      _stream);
  if (!_pre_or_postLayerNorm) {
    // in-place ln since ln-input will not be used in post-ln mode
    _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
}
// Public forward entry point: refresh the stream and cuBLAS handle from the
// global context, then run the fused attention forward pass using the
// process-wide shared workspace (3 * _batch_dim elements).
template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
                                    T *out_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *workspace = _shared_mem_ptr;
  attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, workspace);
}
// Backward pass of the fused attention layer; mirrors attn_layer_fw in
// reverse. `buffer` is carved into regions below; two pairs of pointers
// alias the same storage because their live ranges do not overlap.
template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          const T *output_ptr,
                                          const T *grad_output_ptr,
                                          T *grad_input_ptr, T *buffer) {
  // Both LN backward streams point at the same stream here.
  cudaStream_t streams[2] = {_stream, _stream};
  const T *q_tf_ptr = _qkv_ptr;
  const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
  const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
  // batch_dim = batch_size * seq_len * hidden_size
  // buffer size: batch_dim * 3 + max(batch_dim * 3,
  // batch_size * head_num * seq_len * seq_len)
  T *grad_residual_ptr = buffer;
  buffer += _batch_dim;
  // grad_input_buf and grad_qkv_5d deliberately alias this region.
  T *grad_input_buf_ptr = buffer;  // batch_dim
  T *grad_qkv_5d_ptr = buffer;     // batch_dim * 3
  buffer += 3 * _batch_dim / pg_size;
  // grad_qkv_4d and grad_softmax also alias each other.
  T *grad_qkv_4d_ptr = buffer;   // batch_dim * 3
  T *grad_softmax_ptr = buffer;  // batch_size * head_num * seq_len * seq_len
  // buffer += max(3 * _batch_dim,
  // batch_size * head_num * seq_len * seq_len);
  if (_pre_or_postLayerNorm) {
    // Pre-LN: only the dropout/bias/residual path feeds grad_input here.
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_output_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  } else {
    // Post-LN: undo the in-place layer norm first.
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
                      grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
                      _attn_nb_ptr, _batch_tokens, streams);
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_residual_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  }
  // bw of output project
  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
                            _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
                            _cublasHandle, _stream, grad_input_buf_ptr, nullptr,
                            false);
  // [b, s, nh, ad] -> [b, nh, s, ad] for the batched GEMM backward.
  launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
                           _seq_len, _hidden_size / pg_size, _heads / pg_size,
                           _stream);
  // bw of score * v
  _attn_context.Backward(
      _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
      grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
  _attn_prob_dropout.d_dropout(grad_softmax_ptr,
                               _batch_heads * _seq_len * _seq_len, _stream);
  _softmax.reset_size(_heads / pg_size);
  _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
                    _seq_len, _stream);
  // bw of q * k: writes dQ at offset 0 and dK at offset batch_dim/pg_size;
  // dV was written at offset 2*batch_dim/pg_size above.
  _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
                        _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
                        grad_qkv_5d_ptr);
  // [3, b, nh, s, ad] -> [b, s, 3, h]
  launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
                             _seq_len, _hidden_size / pg_size, _heads / pg_size,
                             3, _stream);
  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
                       _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
                       _cublasHandle, _stream, grad_input_buf_ptr, nullptr,
                       true);
  // allreduce the input gradient across the tensor-parallel group.
  // NOTE(review): same unusual intrusive_ptr null-check idiom as in
  // attn_layer_fw -- confirm intended.
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
  } else {
    auto data_type = torch::kFloat;
    if (typeid(T) != typeid(float)) {
      data_type = torch::kHalf;
    }
    auto grad_input_tensor =
        torch::from_blob(grad_input_buf_ptr,
                         {int(_batch_size), int(_seq_len), int(_hidden_size)},
                         torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
    work->wait();
  }
  if (_pre_or_postLayerNorm) {
    // Pre-LN: fold the LN backward and the residual gradient together.
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
                      grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
                      _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
  } else {
    // FIXME later
    launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
                         _batch_size, _seq_len, _hidden_size, _stream);
  }
}
// Public backward entry point: refresh the stream and cuBLAS handle from
// the global context, then run the attention backward pass on the shared
// scratch buffer.
template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
                                     const T *input_ptr, const T *output_ptr,
                                     const T *input_mask_ptr,
                                     T *grad_input_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  // Scratch requirement (see attn_layer_bw):
  //   4 * _batch_dim + max(3 * _batch_dim,
  //                        _batch_size * _head_num * _seq_len * _seq_len)
  T *scratch = _shared_mem_ptr;
  attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
                grad_input_ptr, scratch);
}
// Propagate train/eval mode to both dropout modules; dropout is skipped
// entirely when training is disabled.
template <typename T>
void MultiHeadAttention<T>::SetTrainingMode(bool training) {
  _attn_dropout.SetTrainingMode(training);
  _attn_prob_dropout.SetTrainingMode(training);
}
// Scratch buffer shared by every layer instance of the same precision;
// allocated lazily by the first allocate_mem_buffer() call.
template <typename T>
T *MultiHeadAttention<T>::_shared_mem_ptr = nullptr;
// Explicit instantiations for the two supported precisions.
template class MultiHeadAttention<float>;
template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// Validate that a tensor argument lives on the GPU and is contiguous.
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
// Registry of live attention layers keyed by layer_id (type-erased so fp32
// and fp16 layers can share one map).
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
// Construct a MultiHeadAttention layer, bind it to the given process group,
// and register it in the global registry under layer_id.
// Returns 0 unconditionally (kept for the existing Python binding contract).
template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens,
                               int max_seq_len, int hidden_dim, int num_heads,
                               float attn_prob_dropout_ratio,
                               float hidden_dropout_ratio,
                               bool pre_or_postLayerNorm,
                               c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
  // Share PyTorch's current stream with the kernel context.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  Context::Instance().set_stream(stream);
  auto layer = std::make_shared<MultiHeadAttention<T>>(
      layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
      attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
  // SetPG derives the shard count and allocates the device buffers.
  layer->SetPG(pg_);
  s_multihead_attention[layer_id] = layer;
  // (removed: an unused local `dtype` string was computed and discarded)
  return 0;
}
// Python-facing forward: looks up the layer registered under layer_id,
// wires the caller-owned parameter tensors into it, and runs Forward.
// input is [batch_size, seq_len, hidden_size]; returns {output} with the
// same shape and dtype.
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
    int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
    bool training_mode, bool prelayernorm) {
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);
  const T *input_ptr = (const T *)input.data_ptr();
  const T *input_mask_ptr = (const T *)input_mask.data_ptr();
  auto output = torch::empty_like(input);
  T *out_ptr = (T *)output.data_ptr();
  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(input.size(0), input.size(1));
  layer->SetTrainingMode(training_mode);
  // Parameter pointers are borrowed; the tensors must outlive Forward().
  layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
  layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
  layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
  layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
  layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
  layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();
  layer->Forward(input_ptr, input_mask_ptr, out_ptr);
  return {output};
}
// Python-facing backward: allocates a gradient tensor matching each
// parameter, wires the gradient pointers into the registered layer, and
// runs Backward. Returns gradients in the order
// {input, in_proj_w, in_proj_b, out_proj_w, out_proj_b, norm_w, norm_b}.
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
    int layer_id, const torch::Tensor &grad_dec_output,
    const torch::Tensor &output, const torch::Tensor &input,
    const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
    const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
    const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
    const torch::Tensor &norm_bias) {
  auto g_output = grad_dec_output.contiguous();
  CHECK_INPUT(g_output);
  CHECK_INPUT(output);
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);
  auto grad_input = torch::empty_like(input);
  auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
  auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
  auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
  auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
  auto grad_norm_weight = torch::empty_like(norm_weight);
  auto grad_norm_bias = torch::empty_like(norm_bias);
  // inputs.
  const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
  const T *input_ptr = (const T *)input.data_ptr();
  const T *output_ptr = (const T *)output.data_ptr();
  const T *input_mask_ptr = (const T *)input_mask.data_ptr();
  // outputs.
  T *grad_input_ptr = (T *)grad_input.data_ptr();
  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
  // Gradient buffers are borrowed; the layer writes into them in Backward.
  layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
  layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
  layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
  layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
  layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
  layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
                  grad_input_ptr);
  return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
          grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
          grad_norm_bias};
}
// Python bindings: forward/backward and layer creation for the fused
// multi-head attention, in fp32 and fp16 variants.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
        "Multi-head Attention forward with fp32 (CUDA)");
  m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
        "Multi-head Attention forward with fp16 (CUDA)");
  m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
        "Multi-head Attention backward with fp32 (CUDA)");
  m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
        "Multi-head Attention backward with fp16 (CUDA)");
  m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
        "Create Multi-head Attention with fp32 (CUDA)");
  m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
        "Create Multi-head Attention with fp16 (CUDA)");
}
#pragma once
#include <c10/util/intrusive_ptr.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif
#include <string>
#include <type_traits>
#include "cuda_util.h"
#include "dropout.h"
#include "feed_forward.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"
// Fused multi-head attention layer. Supports tensor parallelism through an
// optional c10d process group: attention heads are sharded across pg_size
// ranks and partial results are all-reduced in the fw/bw passes.
template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
                     int hidden_size, int num_heads, float attn_dropout_ratio,
                     float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);
  virtual ~MultiHeadAttention();
  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
  void Backward(const T *grad_output_ptr, const T *input_ptr,
                const T *output_ptr, const T *input_mask_ptr,
                T *grad_input_ptr);
  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
                     T *buffer);
  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
                     const T *output_ptr, const T *grad_output_ptr,
                     T *grad_input_attn_layer_bwptr, T *buffer);
  // Record the dynamic batch shape and resize the two batched GEMM configs.
  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;
    _seq_len = seq_len;
    _batch_tokens = batch_size * seq_len;
    _batch_heads = batch_size * _heads / pg_size;
    _batch_dim = _batch_tokens * _hidden_size;
    _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
  }
  void SetTrainingMode(bool training);
  inline bool IsTrainingMode() const { return _training; }
  // Bind the (possibly null) process group, derive the shard count, then
  // allocate device workspaces sized for this shard.
  void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
    pg = pg_;
    pg_size = 1;
    // NOTE(review): comparing an intrusive_ptr with a default
    // UniqueVoidPtr is an unusual null-check idiom -- confirm intended.
    if (pg != c10::detail::UniqueVoidPtr()) {
      pg_size = pg->getSize();
    }
    allocate_mem_buffer();
  }
  // weights ptr -- borrowed from caller-owned tensors, never freed here.
  const T *_attn_qkvw_ptr;
  const T *_attn_qkvb_ptr;
  const T *_attn_ow_ptr;
  const T *_attn_ob_ptr;
  const T *_attn_nw_ptr;
  const T *_attn_nb_ptr;
  // grads ptr -- borrowed output buffers written during Backward.
  T *_grad_attn_qkvw_ptr;
  T *_grad_attn_qkvb_ptr;
  T *_grad_attn_ow_ptr;
  T *_grad_attn_ob_ptr;
  T *_grad_attn_nw_ptr;
  T *_grad_attn_nb_ptr;

 private:
  // Allocate per-layer activation buffers plus the lazily-created scratch
  // buffer shared between all layers.
  void allocate_mem_buffer() {
    // allocate local gpu memory
    if (_pre_or_postLayerNorm) {
      _gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
    } else {
      _gemmQKV_inp_ptr = nullptr;
    }
    _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
    _soft_out_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _ctx_bufB_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
    // buffer size needed by attn bw
    size_t smem_size =
        4 * _max_batch_tokens * _hidden_size / pg_size +
        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
                 _max_batch_tokens * _heads / pg_size * _max_seq_len);
    // The first layer to get here allocates the shared buffer; cuda_free on
    // the still-null pointer below is a no-op.
    if (!_shared_mem_ptr) {
      cuda_free(_shared_mem_ptr);
      _shared_mem_ptr = cuda_malloc<T>(smem_size);
    }
  }
  void free_mem_buffer() {
    // free local gpu memory
    cuda_free(_gemmQKV_inp_ptr);
    cuda_free(_qkv_ptr);
    cuda_free(_soft_out_ptr);
    cuda_free(_ctx_bufB_ptr);
    cuda_free(_attn_o_inp_ptr);
    // free shared gpu memory between layers
    cuda_free(_shared_mem_ptr);
    _shared_mem_ptr = nullptr;
  }
  // const parameter between batch
  const size_t _layer_id;
  const size_t _hidden_size;
  const size_t _heads;
  const size_t _max_batch_tokens;
  const size_t _max_seq_len;
  const bool _pre_or_postLayerNorm;
  // dynamic parameter between batch
  size_t _batch_size;
  size_t _seq_len;
  size_t _batch_tokens;
  size_t _batch_heads;
  size_t _batch_dim;
  bool _training;
  cublasHandle_t _cublasHandle;
  cudaStream_t _stream;
  // layers
  FeedForward<T> _qkv_linear;
  FeedForward<T> _attn_out_linear;
  Normalize_Layer<T> _attn_ln;
  Softmax<T> _softmax;
  Dropout<T> _attn_prob_dropout;
  Dropout<T> _attn_dropout;
  StridedBatchGemm<T> _attn_scores;
  StridedBatchGemm<T> _attn_context;
  // local GPU memory
  T *_gemmQKV_inp_ptr;
  T *_qkv_ptr;
  T *_soft_out_ptr;
  T *_ctx_bufB_ptr;
  T *_attn_o_inp_ptr;
  // shared GPU memory between layer
  static T *_shared_mem_ptr;
  c10::intrusive_ptr<c10d::ProcessGroup> pg;
  int pg_size;
};
#include <torch/extension.h>
#include "linear.h"
// Python bindings for the int8 fused linear + SiLU kernel.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
        "Linear SiLU (INT8)");
}
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
// Fused int8 GEMM with a SiLU epilogue via CUTLASS:
//   out = SiLU(alpha * (input x weight^T) + beta * bias_broadcast)
// input:  [M, K] int8, row-major
// weight: [N, K] int8, consumed column-major (i.e. as K x N)
// bias:   [N] fp32, broadcast to [M, N] and used as the GEMM's C operand
// Returns [M, N] fp32. The tile/arch configuration is chosen at compile
// time from CUDA_ARCH; SM70 falls back to SIMT and applies SiLU via torch.
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias, // FP32
                                            float alpha, // FP32
                                            float beta // FP32
) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);
  using ElementOutput = float;
  using ElementAccumulator = int32_t;
  using ElementComputeEpilogue = float;
  using ElementInputA = int8_t; // <- data type of elements in input matrix A
  using ElementInputB = int8_t; // <- data type of elements in input matrix B
  // The code section below describes matrix layout of input and output
  // matrices. Row Major for Matrix A, Column Major for Matrix B and Row Major
  // for Matrix C
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;
#if CUDA_ARCH >= 800
  // Ampere+: tensor-core int8 GEMM with a fused SiLU epilogue.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value, // <- this is the number of elements per
                                       // vectorized memory access. For half
                                       // precision, it's 8 elements. This
                                       // becomes the vector width of math
                                       // instructions in epilogue too
      ElementAccumulator, // <- data type of accumulator
      ElementComputeEpilogue // <- data type for alpha in linear combination
                             // function
      >;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
      EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
  // Turing: tensor-core int8 GEMM with CUTLASS default tile shapes.
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput, // <- data type of output matrix
      128 / cutlass::sizeof_bits<
                ElementOutput>::value, // <- this is the number of elements per
                                       // vectorized memory access. For half
                                       // precision, it's 8 elements. This
                                       // becomes the vector width of math
                                       // instructions in epilogue too
      ElementAccumulator, // <- data type of accumulator
      ElementComputeEpilogue // <- data type for alpha in linear combination
                             // function
      >;
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      EpilogueOp>;
#elif CUDA_ARCH >= 700
  // Volta: no int8 tensor cores -- SIMT GEMM, SiLU applied afterwards via
  // torch (see USE_TORCH_SILU below).
#define USE_TORCH_SILU
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif
  auto input_size = cutlass::MatrixCoord(M, K);
  auto weight_size = cutlass::MatrixCoord(K, N);
  auto output_size = cutlass::MatrixCoord(M, N);
  auto device = input.device();
  // use the broadcasted bias as the output
  auto out = bias.to(device).view({1, -1}).repeat({M, 1});
  // constexpr int kSparse = Gemm::kSparse;
  // How many elements of A are covered per ElementE
  // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
  // The size of individual meta data
  // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
  cutlass::gemm::GemmCoord problem_size(M, N, K);
  // TensorRefs borrow the torch tensors' device memory; no copies are made.
  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
      input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
      weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
      out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
  typename Gemm::Arguments arguments{
      problem_size, // <- problem size of matrix multiplication
      input_ref, // <- reference to matrix A on device
      weight_ref, // <- reference to matrix B on device
      out_ref, // <- reference to matrix C on device
      out_ref, // <- reference to matrix D on device
      {alpha, beta}, 1};
  Gemm gemm_op;
  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  // Check the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }
  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }
  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
  // SIMT path used a plain linear-combination epilogue; apply SiLU here.
  out = torch::silu(out);
#endif
  return out;
}
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
// Fused linear + SiLU for int8 activations/weights with fp32 bias and fp32
// output: out = SiLU(alpha * (input @ weight^T) + beta * bias).
// input [M, K] int8, weight [N, K] int8, bias [N] fp32 -> [M, N] fp32.
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
                                            torch::Tensor weight, // INT8
                                            torch::Tensor bias, // FP32
                                            float alpha, // FP32
                                            float beta // FP32
);
from .mha import ColoAttention
__all__ = ["ColoAttention"]
import warnings
from typing import Optional
import torch
def is_ampere_or_better_gpu():
    """Return True when a CUDA device with compute capability >= 8.0 (Ampere or newer) is present."""
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(torch.device("cuda"))
    # Ampere GPUs report major compute capability 8; later generations report higher.
    return props.major >= 8
# "Check Ampere GPUs or newer"
HAS_FLASH_ATTN = False
if is_ampere_or_better_gpu():
    HAS_FLASH_ATTN = True
else:
    warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
    HAS_FLASH_ATTN = False
# NOTE(review): the flag set above is unconditionally overwritten by the
# import probe below; the hardware check only serves to emit a warning.
try:
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

    HAS_FLASH_ATTN = True
except ImportError:
    warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
    HAS_FLASH_ATTN = False
if HAS_FLASH_ATTN:
    pass
from .utils import SeqLenInfo
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_len_info_q: SeqLenInfo,
    seq_len_info_kv: SeqLenInfo,
    bias: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    scale: float = None,
    causal: bool = False,
    padded: bool = False,
):
    """
    Arguments:
        q: (batch, q_seqlen, nheads, headdim)
        k: (batch, kv_seqlen, nheads, headdim)
        v: (batch, kv_seqlen, nheads, headdim)
        seq_len_info_q / seq_len_info_kv: cumulative-length metadata used by
            the variable-length kernel when ``padded`` is True.
        bias: accepted for interface parity but NOT forwarded to the
            flash-attention kernels in this wrapper.
        dropout_p: float. Dropout probability.
        scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for
            auto-regressive modeling).
        padded: bool. Use the var-len kernel driven by cu_seqlens.
    Return:
        attn_out: (batch, q_seqlen, nheads, headdim).
    """
    if padded:
        # Fall back to the query lengths when no KV lengths are provided.
        if seq_len_info_kv is None:  # fixed: identity check instead of `== None`
            seq_len_info_kv = seq_len_info_q
        attn_out = flash_attn_varlen_func(
            q,
            k,
            v,
            seq_len_info_q.cu_seqlens,
            seq_len_info_kv.cu_seqlens,
            seq_len_info_q.max_seqlen,
            seq_len_info_kv.max_seqlen,
            dropout_p,
            scale,
            causal,
        )
    else:
        attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
    return attn_out
import warnings

HAS_MEM_EFF_ATTN = False
try:
    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
    from xformers.ops.fmha.attn_bias import (
        BlockDiagonalCausalMask,
        BlockDiagonalMask,
        LowerTriangularMask,
        LowerTriangularMaskWithTensorBias,
    )

    HAS_MEM_EFF_ATTN = True
except ImportError:
    warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
    HAS_MEM_EFF_ATTN = False

if HAS_MEM_EFF_ATTN:
    """
    A general attention module using the flash attention kernels from xformers:
    https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
    """
    from typing import Optional

    import torch

    from .utils import SeqLenInfo

    # Alibi-style tensor biases are usable only if every cutlass op accepts
    # LowerTriangularMaskWithTensorBias.
    allow_alibi = True
    for op in MemoryEfficientAttentionCutlassOp:
        allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
    def mem_eff_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seq_len_info_q: SeqLenInfo,
        seq_len_info_kv: SeqLenInfo,
        bias: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: float = None,
        causal: bool = False,
        padded: bool = False,
    ):
        """Attention via xformers' memory_efficient_attention.

        Builds the appropriate attn_bias (block-diagonal masks for packed
        "bert style" batches, lower-triangular for "gpt style" causal runs,
        optionally with an additive alibi bias), then dispatches to xformers.
        # assumes q/k/v are (batch, seqlen, nheads, headdim), or packed
        # (ntokens, nheads, headdim) when `padded` is True -- TODO confirm.
        """
        attn_bias = None
        if padded:  # bert style
            if not causal:
                attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            else:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
        elif causal:  # gpt style
            attn_bias = LowerTriangularMask()

        if bias is not None:  # alibi / relative position embedding
            assert allow_alibi, "flash attention with bias is not supported in this system."
            assert causal, "attention with bias is only supported for causal attention so far."
            attn_bias = attn_bias.add_bias(bias)

        if padded:
            # Packed tokens are treated as one batch row; the block-diagonal
            # mask keeps samples from attending across each other.
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

        # shape: (b*s, n, d)
        if padded:
            out = out.squeeze(0)

        return out
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.utils.device import get_current_device
class Unpad(torch.autograd.Function):
    """
    Remove padding: gather rows of a [b, s, ...] tensor at `indices` into a
    packed [ntokens, ...] tensor; backward scatters gradients back.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        # Save indices so backward can scatter gradients to their origin.
        ctx.save_for_backward(indices)
        # [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        out = rearrange(tensor, "b s ... -> (b s) ...")
        ctx.shape = out.shape
        # [ntokens, ...]
        return out[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [ntokens, ...] -> zero-filled [b*s, ...] with grads scattered back
        # to the token positions selected in forward.
        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
        grad[indices] = grad_output
        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
        # [b, s, ...]; no gradient for `indices`.
        return grad, None
class Repad(torch.autograd.Function):
    """
    Inverse of Unpad: scatter a packed [ntokens, ...] tensor back into a
    zero-padded [batch_size * seq_len, ...] tensor.

    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        """Place `tensor` rows at `indices` of a zero-filled (batch_size * seq_len, ...) output."""
        ctx.save_for_backward(indices)
        # [ntokens, ...] -> [b*s, ...]; rows not listed in `indices` stay
        # zero (padding). (Removed a no-op `tensor = tensor` statement.)
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Gather gradients back into the packed [ntokens, ...] layout."""
        (indices,) = ctx.saved_tensors
        # [b*s, ...] -> [ntokens, ...]
        grad = grad_output[indices]
        # No gradients for indices / batch_size / seq_len.
        return grad, None, None, None
@dataclass
class SeqLenInfo:
    """Sequence-length metadata for variable-length attention kernels.

    Attributes:
        seqlens: per-sample sequence lengths (list of ints).
        indices: flat indices of non-pad tokens in the [b*s] layout.
        max_seqlen: largest sequence length in the batch.
        cu_seqlens: int32 cumulative lengths with a leading 0 (the format
            flash-attention's varlen kernels expect).
    """

    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
        # NOTE(review): the default `device` is evaluated once at import time;
        # pass `device` explicitly if the current device may change later.
        if attn_mask is not None:
            # Non-pad token positions and per-row token counts from the mask.
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            # No mask: every position of every sample is a real token.
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            # Fixed: the legacy `torch.LongTensor(data, device=...)`
            # constructor rejects non-CPU devices; torch.full builds the
            # same int64 tensor directly on `device`.
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
        # NOTE(review): max over a tensor yields a 0-dim tensor, not a plain
        # int as the annotation suggests -- confirm downstream expectations.
        max_seqlen = max(seqlens)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.autograd import Function
def check_config(config):
    """Validate kernel constraints on an attention config.

    Args:
        config: object with `hidden_size`, `nhead`, and `fp16` attributes.
    Raises:
        ValueError: if hidden_size is not divisible by nhead, exceeds the
            layer-norm backward kernel's limit, or head_dim is not a
            multiple of the vectorization factor (8 for fp16, 4 for fp32).
    """
    if config.hidden_size % config.nhead != 0:
        raise ValueError("hidden_size % nhead != 0")
    # fp16 kernels move 8 elements (one float4 of halves) per thread; fp32 move 4.
    factor = 8 if config.fp16 else 4
    upbound = factor * 1024 * 4
    if config.hidden_size > upbound:
        # as required by ln backward kernel currently
        raise ValueError(f"hidden_size > {upbound}")
    head_dim = config.hidden_size // config.nhead
    if head_dim % factor != 0:
        # as required by reshape kernel
        raise ValueError(f"head_dim({head_dim}) % {factor} != 0")
def calc_offset(sizes):
    """Return exclusive prefix sums of ``sizes``, starting with 0.

    E.g. [1, 2, 3] -> [0, 1, 3, 6]; the result has len(sizes) + 1 entries.
    """
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    return offsets
# Handle to the compiled multi-head attention extension module; populated
# lazily by MultiHeadAttention.__init__ on first instantiation.
colossal_multihead_attention = None
@dataclass
class Config:
    """Static configuration handed to the fused multi-head attention CUDA kernel.

    NOTE(review): ``layer_id``, ``training`` and ``is_grad_enabled`` are
    attached to instances dynamically elsewhere in this file, not declared here.
    """

    max_batch_tokens: int  # max number of tokens (batch_size * max_seq_len) per forward
    max_seq_len: int  # max sequence length
    hidden_size: int  # size of transformer hidden layers
    nhead: int  # number of heads in attention
    attn_prob_dropout_ratio: float  # attention score dropout ratio
    hidden_dropout_ratio: float  # dropout ratio applied before the residual add
    norm_first: bool  # apply LayerNorm before attention (pre-norm)
    fp16: bool  # fp16 precision
class MultiHeadAttention1DFunc(Function):
    """Autograd bridge to the fused multi-head attention CUDA kernels.

    Dispatches to the fp16 or fp32 entry points of the lazily loaded
    ``colossal_multihead_attention`` extension module. The positional argument
    order of those calls mirrors the C++ binding and must not be changed.
    """

    @staticmethod
    def forward(
        ctx,
        input,
        input_mask,
        in_proj_weight,
        in_proj_bias,
        out_proj_weight,
        out_proj_bias,
        norm_weight,
        norm_bias,
        config,
    ):
        # `config` carries layer_id / training / is_grad_enabled flags that
        # MultiHeadAttention.forward sets right before calling apply().
        cuda_module = colossal_multihead_attention
        forward_func = (
            cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32
        )
        if config.fp16:
            # The fp16 kernel expects half-precision activations and mask.
            input = input.to(torch.half)
            input_mask = input_mask.to(torch.half)

        (output,) = forward_func(
            config.layer_id,
            input,
            input_mask,
            in_proj_weight,
            in_proj_bias,
            out_proj_weight,
            out_proj_bias,
            norm_weight,
            norm_bias,
            config.training,
            config.norm_first,
        )

        if config.is_grad_enabled and config.training:
            # Tensors are only saved when a backward pass can actually happen.
            ctx.save_for_backward(
                output,
                input,
                input_mask,
                in_proj_weight,
                in_proj_bias,
                out_proj_weight,
                out_proj_bias,
                norm_weight,
                norm_bias,
            )
            ctx.config = config
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # forward() only saves ctx state when training; guard against misuse.
        assert ctx.config.training

        cuda_module = colossal_multihead_attention
        backward_func = (
            cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32
        )

        (
            output,
            input,
            input_mask,
            in_proj_weight,
            in_proj_bias,
            out_proj_weight,
            out_proj_bias,
            norm_weight,
            norm_bias,
        ) = ctx.saved_tensors

        grad_input = None
        grad_in_proj_weight = None
        grad_in_proj_bias = None
        grad_out_proj_weight = None
        grad_out_proj_bias = None
        grad_norm_weight = None
        grad_norm_bias = None

        if ctx.config.fp16:
            # Match the fp16 kernel's expected input precision.
            grad_output = grad_output.to(torch.half)
            output = output.to(torch.half)
            input = input.to(torch.half)
            input_mask = input_mask.to(torch.half)
        (
            grad_input,
            grad_in_proj_weight,
            grad_in_proj_bias,
            grad_out_proj_weight,
            grad_out_proj_bias,
            grad_norm_weight,
            grad_norm_bias,
        ) = backward_func(
            ctx.config.layer_id,
            grad_output,
            output,
            input,
            input_mask,
            in_proj_weight,
            in_proj_bias,
            out_proj_weight,
            out_proj_bias,
            norm_weight,
            norm_bias,
        )

        # One gradient slot per forward() argument; None for input_mask and config.
        return (
            grad_input,
            None,
            grad_in_proj_weight,
            grad_in_proj_bias,
            grad_out_proj_weight,
            grad_out_proj_bias,
            grad_norm_weight,
            grad_norm_bias,
            None,
        )
class MultiHeadAttention(nn.Module):
    """Initialize the MultiHeadAttention.

    Wraps the fused CUDA multi-head attention kernels in an ``nn.Module`` whose
    projection parameters may be partitioned across a tensor-parallel process
    group.

    Static variable:
        layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
            e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.

    Arguments:
        hidden_size: Total dimension of hidden_size.
        nhead: Number of parallel attention heads.
        batch_size: Batch Size for one forward
        max_seq_len: Max length of input sequence
        dropout: Dropout probability
        norm_first: perform LayerNorms before attention
        fp16: use half-precision kernels and compute in torch.half (default True)
        pg: optional process group; when given, the QKV and output projection
            parameters are sharded across its ranks
    """

    layer_id = 0

    def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
        super(MultiHeadAttention, self).__init__()

        # Attention-score dropout and hidden dropout share the same ratio here.
        self.config = Config(
            batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
        )
        check_config(self.config)

        # Tensor-parallel process group (pg_size == 1 means no partitioning).
        self.pg = pg
        self.pg_size = 1
        if self.pg:
            self.pg_size = pg.size()

        # Assign a unique id so the CUDA side can track per-layer state.
        self.config.layer_id = MultiHeadAttention.layer_id
        MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1

        # Load cuda modules if needed (done once, shared by all layers).
        global colossal_multihead_attention
        if colossal_multihead_attention is None:
            from colossalai.kernel.op_builder import MultiHeadAttnBuilder

            multihead_attention = MultiHeadAttnBuilder().load()
            colossal_multihead_attention = multihead_attention

        # create the layer in cuda kernels.
        cuda_module = colossal_multihead_attention
        create_layer_func = (
            cuda_module.create_multihead_attention_fp16
            if self.config.fp16
            else cuda_module.create_multihead_attention_fp32
        )

        create_layer_func(
            self.config.layer_id,
            self.config.max_batch_tokens,
            self.config.max_seq_len,
            self.config.hidden_size,
            self.config.nhead,
            self.config.attn_prob_dropout_ratio,
            self.config.hidden_dropout_ratio,
            self.config.norm_first,
            self.pg,
        )

        hs = self.config.hidden_size

        self.precision = torch.float32
        if self.config.fp16:
            self.precision = torch.half

        # Each rank only holds its 1/pg_size shard of the projections.
        self.hs_per_rank = int(hs / self.pg_size)

        self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))  # stacked Q/K/V weights
        self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))  # stacked Q/K/V biases
        self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
        self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
        self.norm_weight = nn.Parameter(torch.Tensor(hs))
        self.norm_bias = nn.Parameter(torch.Tensor(hs))

        self.reset_parameters()
        torch.cuda.empty_cache()

    def calc_bound(self, w):
        # Uniform-init bound 1/sqrt(fan_in), mirroring nn.Linear's default bias init.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
        bound = 1.0 / math.sqrt(fan_in)
        return bound

    def reset_parameters(self):
        """Initialize all parameters.

        In the tensor-parallel case, each rank initializes full-size tensors,
        rank 0's values are broadcast, and every rank keeps only its own slice —
        so all shards start from one consistent global initialization.
        """
        hs = self.config.hidden_size

        nn.init.zeros_(self.out_proj_bias)

        nn.init.ones_(self.norm_weight)
        nn.init.zeros_(self.norm_bias)

        if self.pg_size > 1:
            rank_in_pg = torch.distributed.get_rank(self.pg)
            # Global (unsharded) QKV parameters, initialized on every rank.
            attn_qkvw_global = torch.empty(hs * 3, hs)
            attn_qkvb_global = torch.empty(hs * 3)
            nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw_global)
            nn.init.uniform_(attn_qkvb_global, -bound, bound)

            # Broadcast rank 0's values so all ranks agree before sharding.
            attn_qkvw_global = attn_qkvw_global.cuda()
            attn_qkvb_global = attn_qkvb_global.cuda()
            torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
            torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
            attn_qkvw_global = attn_qkvw_global.cpu()
            attn_qkvb_global = attn_qkvb_global.cpu()

            # Copy this rank's slice of the broadcast parameters.
            with torch.no_grad():
                self.in_proj_weight.copy_(
                    attn_qkvw_global.view(3, hs, hs)[
                        :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
                    ]
                )
                self.in_proj_bias.copy_(
                    attn_qkvb_global.view(3, hs)[
                        :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
                    ]
                )

            # Same broadcast-then-shard scheme for the output projection weight.
            attn_ow_global = torch.empty(hs, hs)
            nn.init.xavier_uniform_(attn_ow_global, 1.0)
            attn_ow_global = attn_ow_global.cuda()
            torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
            attn_ow_global = attn_ow_global.cpu()
            with torch.no_grad():
                self.out_proj_weight.copy_(
                    attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
                )
        else:
            # Single-rank case: initialize the local parameters directly.
            attn_qkvw = self.in_proj_weight.view(-1, hs)
            nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw)
            nn.init.uniform_(self.in_proj_bias, -bound, bound)

            nn.init.xavier_uniform_(self.out_proj_weight, 1.0)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # Delegate straight to nn.Module's implementation.
        destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
        return destination

    def forward(self, hidden_states, encoder_padding_mask):
        # Snapshot the current autograd/training mode for the kernel dispatch
        # inside MultiHeadAttention1DFunc.
        self.config.training = self.training
        self.config.is_grad_enabled = torch.is_grad_enabled()
        hidden_states = hidden_states.contiguous()
        # Scale the mask into an additive bias of -1e8 at nonzero positions.
        # NOTE(review): this assumes 1 marks padded positions — confirm against
        # the kernel's mask convention.
        encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()

        bs, sl, dim = hidden_states.size()
        if bs * sl > self.config.max_batch_tokens:
            raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
        if sl > self.config.max_seq_len:
            raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")

        # A 1-D mask is only allowed for a single unbatched sequence.
        if len(encoder_padding_mask.size()) == 1:
            assert bs == 1 and sl == encoder_padding_mask.size(0)
        else:
            assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)

        output = MultiHeadAttention1DFunc.apply(
            hidden_states,
            encoder_padding_mask,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.norm_weight,
            self.norm_bias,
            self.config,
        )

        return output.to(self.precision)
../../extensions
\ No newline at end of file
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from colossalai.utils import get_current_device
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
......@@ -46,11 +46,13 @@ def warmup_jit_fusion(
):
"""Compile JIT functions before the main training steps"""
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
x = torch.randint(
vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
)
x = embed(x)
y, y_bias = linear_1(x)
z, z_bias = linear_2(y)
......@@ -58,8 +60,8 @@ def warmup_jit_fusion(
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
for _ in range(10):
bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
bias_gelu_impl(input_, bias)
......@@ -69,9 +71,9 @@ def warmup_jit_fusion(
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
for _ in range(10):
input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
input_.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
......
import warnings
from typing import List
from .extensions import (
CpuAdamArmExtension,
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension
# Public loader API of this module. BUGFIX: FlashAttentionLoader is defined
# below but was missing from __all__.
__all__ = [
    "KernelLoader",
    "CPUAdamLoader",
    "LayerNormLoader",
    "MoeLoader",
    "FusedOptimizerLoader",
    "ScaledMaskedSoftmaxLoader",
    "ScaledUpperTriangleMaskedSoftmaxLoader",
    "FlashAttentionLoader",
]
class KernelLoader:
    """
    An abstract class which offers encapsulation to the kernel loading process.

    Usage:
        kernel_loader = KernelLoader()
        kernel = kernel_loader.load()
    """

    REGISTRY: List[_Extension] = []

    @classmethod
    def register_extension(cls, extension: _Extension):
        """
        This classmethod is an extension point which allows users to register their customized
        kernel implementations to the loader.

        Args:
            extension (_Extension): the extension to be registered.
        """
        cls.REGISTRY.append(extension)

    def load(self, ext_name: str = None):
        """
        Load the kernel according to the current machine.

        Args:
            ext_name (str): the name of the extension to be loaded. If not specified, the loader
                will try to look for an kernel available on the current machine.
        """
        candidates = [ext_cls() for ext_cls in self.__class__.REGISTRY]

        if ext_name:
            # An explicit name bypasses the hardware-availability scan.
            usable = [ext for ext in candidates if ext.name == ext_name]
        else:
            usable = []
            for candidate in candidates:
                if candidate.is_hardware_available():
                    # make sure the machine is compatible during kernel loading
                    candidate.assert_hardware_compatible()
                    usable.append(candidate)

        assert len(usable) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."

        if len(usable) > 1:
            # Several kernels qualify: pick the one with the highest priority.
            usable = sorted(usable, key=lambda ext: ext.priority, reverse=True)
            warnings.warn(
                f"More than one kernel is available, loading the kernel with the highest priority - {usable[0].__class__.__name__}"
            )
        return usable[0].load()
class CPUAdamLoader(KernelLoader):
    """Loader for the fused CPU Adam kernel (x86 and ARM implementations registered)."""

    REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
class LayerNormLoader(KernelLoader):
    """Loader for the CUDA layer-norm kernel."""

    REGISTRY = [LayerNormCudaExtension]
class MoeLoader(KernelLoader):
    """Loader for the CUDA mixture-of-experts kernel."""

    REGISTRY = [MoeCudaExtension]
class FusedOptimizerLoader(KernelLoader):
    """Loader for the CUDA fused optimizer kernel."""

    REGISTRY = [FusedOptimizerCudaExtension]
class ScaledMaskedSoftmaxLoader(KernelLoader):
    """Loader for the CUDA scaled masked softmax kernel."""

    REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
    """Loader for the CUDA scaled upper-triangle (causal) masked softmax kernel."""

    REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]
class FlashAttentionLoader(KernelLoader):
    """Loader for flash attention; NPU, Dao-AILab CUDA and xFormers CUDA backends registered."""

    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
../../op_builder
\ No newline at end of file
......@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
......@@ -28,7 +28,7 @@ def load_fused_optim():
global fused_optim
if fused_optim is None:
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.utils.device import autocast
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
autocast = get_accelerator().autocast
class TorchAMPOptimizer(OptimizerWrapper):
"""A wrapper class which integrate Pytorch AMP with an optimizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment