Commit 4acf0e01 authored by aiss

delete hip file

parent 7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
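// rotate_every_two path: elements are handled as (even, odd) pairs. Each lane
// flips the sign of odd elements, shfl_xor(., 1) swaps pair partners, and the
// pair is rotated by theta = seq_id * 10000^(-2i / rotary_dim) with i = lane / 2.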
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
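// Per-lane __shfl_sync participation masks for the rotate-half exchange below.
// Entry i pairs lane i with lane i +/- 12, so this table hard-codes a pair
// distance of half_dim == 12 (i.e. rotary_dim == 24); lanes 24-31 participate alone.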
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
else if (rotate_half)
hipLaunchKernelGGL(( apply_rotary_pos_emb1), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
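// A minimal host-side usage sketch for the launcher above. The dimensions and
// buffer names here are hypothetical (GPT-J-style head geometry), not taken from
// this commit; query/key are assumed to be device buffers laid out as
// [batch, seq_len, num_heads, head_size], flattened.
static void example_rotary_launch(__half* query, __half* key, hipStream_t stream)
{
    const unsigned head_size = 256, seq_len = 1, rotary_dim = 64;
    const unsigned cache_offset = 128;  // tokens already stored in the KV cache
    const unsigned num_heads = 16, batch = 1;
    // rotate_every_two selects the kernel that rotates adjacent (even, odd)
    // element pairs; rotate_half would select the half-split variant instead.
    launch_apply_rotary_pos_emb<__half>(query, key, head_size, seq_len, rotary_dim,
                                        cache_offset, num_heads, batch,
                                        /*rotate_half=*/false,
                                        /*rotate_every_two=*/true, stream);
}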
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12)
                           : __shfl_sync(mask[lane], q_rot, lane - 12);  // g.shfl_xor(q_rot, 12);
auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], k_rot, lane + 12)
                           : __shfl_sync(mask[lane], k_rot, lane - 12);  // g.shfl_xor(k_rot, 12);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
*/
#include "custom_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
#endif
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
cudaStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
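// A host-side reference of the same dequantization, sketched for unit-testing
// the kernel above (dequantize_reference is a hypothetical helper, not part of
// this file). It mirrors the kernel's indexing exactly: qscale holds `groups`
// scales per merged slice, interleaved across the `1 << merge_count` merged
// partitions.
static void dequantize_reference(float* output,
                                 const int8_t* input,
                                 const float* qscale,
                                 int output_size,
                                 int hidden_dim,
                                 int groups,
                                 int merge_count)
{
    unsigned merge_hidden = hidden_dim >> merge_count;
    unsigned quantization_stride = (merge_hidden * output_size) / groups;
    for (int bid = 0; bid < hidden_dim; bid++) {
        for (int tid = 0; tid < output_size; tid++) {
            unsigned w_index = bid / merge_hidden;
            unsigned q_index = tid + bid * output_size;
            unsigned scale_index =
                ((((bid - w_index * merge_hidden) + tid * merge_hidden) / quantization_stride)
                 << merge_count) +
                w_index;
            output[q_index] = qscale[scale_index] * (float)input[q_index];
        }
    }
}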
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
#endif
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
#include "custom_cuda_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715f;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
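// For reference, the device function above is the tanh approximation of GELU:
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A plain
// host mirror (a validation sketch, not part of this file) looks like:
static inline float gelu_host(float x)
{
    const float sqrt_2_over_pi = 0.79788456080286535587989211986876f;
    return x * 0.5f * (1.0f + tanhf(sqrt_2_over_pi * (x + 0.044715f * x * x * x)));
}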
__global__ void fused_bias_gelu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
input_cast[offset] = data;
}
}
__global__ void fused_bias_gelu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
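// Vectorization trick used by this and the other half-precision kernels in this
// file: each thread loads one float2 (8 bytes, i.e. four halves), reinterprets
// it as two __half2 values, and widens those to float2 for the arithmetic.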
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
cudaStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % hidden_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
input_cast[offset] = data;
}
}
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % hidden_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
{
int total_count = batch_size * (hidden_size / 4);
int threads = 1024; // hidden_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size / 4);
}
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);
__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, cudaStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
cudaStream_t);
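// Per-element arithmetic of the fused kernel above, written out as a scalar
// sketch (bias_residual_reference is a hypothetical helper). Note the launcher
// passes 1.0 / mp_size, so the kernel's mp_size parameter is really the
// reciprocal of the model-parallel world size and must be a float:
static inline float bias_residual_reference(float input, float output, float attn,
                                            float bias, float attn_bias,
                                            float inv_mp_size)
{
    return (input + attn) * inv_mp_size + (output + bias + attn_bias);
}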
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
cudaStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
cudaStream_t);
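// The gptj_residual_add kernels above differ from fused_bias_residual in that
// only the hidden-state input is rescaled by the reciprocal model-parallel
// size: out = input * (1 / mp_size) + (output + attn + bias + attn_bias),
// matching GPT-J's parallel attention + MLP residual path.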
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
float4* coef_cast2 = coef_cast + hidden_dim;
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast2[tid];
float2 data = mlp_out_cast[tid];
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];
mlp_out_cast[tid] = data;
tid += blockDim.x;
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
moe_res_matmul<<<grid_dim, block_dim, 0, stream>>>(
residual, coef, mlp_out, seq_len, hidden_dim / 4);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
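// Host-side sketch of the per-token MoE residual blend computed above
// (moe_res_matmul_reference is a hypothetical helper). coef is assumed to
// store two coefficient rows of hidden_dim elements back to back:
// mlp_out[s][h] = mlp_out[s][h] * coef2[h] + residual[s][h] * coef1[h].
static void moe_res_matmul_reference(const float* residual,
                                     const float* coef,
                                     float* mlp_out,
                                     int seq_len,
                                     int hidden_dim)
{
    const float* coef2 = coef + hidden_dim;
    for (int s = 0; s < seq_len; s++)
        for (int h = 0; h < hidden_dim; h++)
            mlp_out[s * hidden_dim + h] = mlp_out[s * hidden_dim + h] * coef2[h] +
                                          residual[s * hidden_dim + h] * coef[h];
}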
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715f;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
__global__ void fused_bias_gelu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
input_cast[offset] = data;
}
}
__global__ void fused_bias_gelu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % hidden_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
input_cast[offset] = data;
}
}
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % hidden_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, hipStream_t stream)
{
int total_count = batch_size * (hidden_size / 4);
int threads = 1024; // hidden_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_add), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, total_count, hidden_size / 4);
}
template void launch_bias_add<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( fused_bias_residual), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, hipStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( gptj_residual_add), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
hipStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
float4* coef_cast2 = coef_cast + hidden_dim;
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast2[tid];
float2 data = mlp_out_cast[tid];
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];
mlp_out_cast[tid] = data;
tid += blockDim.x;
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
hipLaunchKernelGGL(( moe_res_matmul), dim3(grid_dim), dim3(block_dim), 0, stream,
residual, coef, mlp_out, seq_len, hidden_dim / 4);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
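// Two-level reduction: a shfl_down tree within each warp leaves the warp's
// partial sum in lane 0; the per-warp partials are staged in shared memory,
// re-read by the low lanes, reduced again, and g.shfl(sum, 0) broadcasts the total.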
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
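// Plain host reference of the row-wise LayerNorm computed above (a validation
// sketch; layer_norm_reference is a hypothetical helper). As in the kernel,
// epsilon is added to the variance before the reciprocal square root:
static void layer_norm_reference(float* out,
                                 const float* vals,
                                 const float* gamma,
                                 const float* beta,
                                 float epsilon,
                                 int batch_size,
                                 int hidden_dim)
{
    for (int r = 0; r < batch_size; r++) {
        const float* row = vals + r * hidden_dim;
        float mean = 0.f, var = 0.f;
        for (int i = 0; i < hidden_dim; i++) mean += row[i];
        mean /= hidden_dim;
        for (int i = 0; i < hidden_dim; i++) var += (row[i] - mean) * (row[i] - mean);
        float inv_std = 1.f / sqrtf(var / hidden_dim + epsilon);
        for (int i = 0; i < hidden_dim; i++)
            out[r * hidden_dim + i] = (row[i] - mean) * inv_std * gamma[i] + beta[i];
    }
}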
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
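// Layer norm over one row per block: pass 1 accumulates the row sum to get the
// mean, pass 2 accumulates squared deviations to get the variance, and the final
// loop applies (x - mean) * rsqrt(var + eps) * gamma + beta. Reductions use warp
// shuffles plus a shared-memory exchange across warps.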
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
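// Residual variant of the kernel above: when mlp_after_attn is set, the residual
// and bias are added to the input before the same two-pass mean/variance
// normalization; the (currently disabled) preLN path would also write the
// pre-norm sum out through res_add.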
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);  // alt: inp_reg[k]
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <array>
#include <vector>
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
#define MAX_OUT_TOKES 10
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
bool async_op)
{
auto attn_scores_c = attn_scores.contiguous();
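// Scores may arrive as [bsz, seq_len, soft_len] or [bsz, heads, seq_len, soft_len];
// derive the dimensions from the tensor rank accordingly.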
int bsz = attn_scores_c.size(0);
int seq_len = attn_scores_c.size(1);
int len = attn_scores_c.sizes().size();
if (len > 3) seq_len = attn_scores_c.size(2);
int soft_len = attn_scores_c.size(2);
if (len > 3) soft_len = attn_scores_c.size(3);
int heads = 1;
if (len > 3) heads = attn_scores_c.size(1);
launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
1.0,
Context::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t max_seq_len,
size_t batch_size,
size_t head_size = 128)
{
size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
auto options = at::TensorOptions()
.dtype(Q.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
if (!workspace) {
allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
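// einsum('sec,sm->ecm'): flatten Q to an (s, e*c) matrix and contract the shared
// s dimension with W (s, m), producing O with shape (e, c, m).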
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_T,
m,
n,
k,
&alpha,
&gemm_beta,
(T*)W.data_ptr(),
(T*)Q.data_ptr(),
(T*)O.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return O;
}
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
at::Tensor& query_cont,
at::Tensor& attn_mask,
at::Tensor& prev_value_cont,
at::Tensor& output,
int& bsz,
int& seq_len,
int& soft_len,
int& heads,
float& norm_factor,
bool triangular,
bool recompute,
bool local_attention,
int window_size)
{
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
float alpha = norm_factor;
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
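// Attention core as two strided-batched GEMMs around a fused softmax:
// scores = norm_factor * (query x key) per head, softmax over soft_len,
// then output = scores x value.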
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
&alpha,
&gemm_beta,
(T*)prev_key_cont.data_ptr(),
(T*)query_cont.data_ptr(),
(T*)attn_score.data_ptr(),
CUBLAS_OP_N,
CUBLAS_OP_N,
soft_len * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
attn_score = ds_softmax<T>(
attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
&alpha,
&gemm_beta,
(T*)prev_value_cont.data_ptr(),
(T*)attn_score.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_OP_N,
CUBLAS_OP_N,
soft_len * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
at::Tensor& prev_key,
at::Tensor& new_key,
at::Tensor& attn_mask,
at::Tensor& prev_value,
at::Tensor& new_value,
int heads,
float norm_factor,
bool merging,
bool triangular,
bool local_attention,
int window_size,
bool no_masking)
{
auto query_cont = query.contiguous();
auto prev_key_cont = prev_key.contiguous();
auto prev_value_cont = prev_value.contiguous();
int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int bsz = query_cont.size(0);
int seq_len = query_cont.size(1);
int soft_len = prev_value.size(1);
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output =
at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
attention_unfused<T>(prev_key_cont,
query_cont,
attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont,
output,
bsz,
seq_len,
soft_len,
heads,
norm_factor,
(triangular && (new_size == 0)),
(new_size == 0),
local_attention,
window_size);
return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);
launch_bias_gelu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto residual_cont = residual.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
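// NOTE: the fused bias+residual launch below is commented out, so this op
// currently returns the (contiguous) input unchanged.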
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
launch_layer_norm((T*)inp_norm.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)gamma.data_ptr(),
(T*)betta.data_ptr(),
epsilon,
bsz,
input_cont.size(2),
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
// cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
int bsz = input.size(0) * input.size(1);
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm =
qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
return {output, inp_norm};
}
template <typename T>
void quantized_gemm(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int merge_count)
{
int bsz = input.size(0) * input.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
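// Dequantize the int8 weights into a temporary buffer of type T using the
// per-group scales, then run a standard GEMM against it.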
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(1),
weight.size(0),
groups,
merge_count,
Context::Instance().GetCurrentStream());
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool add_bias)
{
int bsz = input.size(0) * input.size(1);
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& q_scale,
int groups)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int bsz = input_cont.size(0) * input_cont.size(1);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
int groups,
int merge_count)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
return output;
}
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
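// MLP pipeline: fused residual layer-norm, then a GEMM into the intermediate
// size, then a fused bias + GeLU epilogue.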
auto inp_norm = at::empty_like(input);
launch_residual_layer_norm((T*)inp_norm.data_ptr(),
(T*)nullptr,
(T*)input.data_ptr(),
(T*)residual.data_ptr(),
(T*)input_bias.data_ptr(),
(T*)gamma.data_ptr(),
(T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
preLayerNorm,
mlp_after_attn,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
return output;
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
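// NOTE: with the residual layer-norm launch below commented out, inp_norm is
// consumed by quantized_gemm while still uninitialized.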
// computing the blocking across K dimension
// launch_residual_layer_norm((T*)inp_norm.data_ptr(),
// (T*)residual_add.data_ptr(),
// (T*)input_cont.data_ptr(),
// (T*)residual.data_ptr(),
// (T*)input_bias.data_ptr(),
// (T*)gamma.data_ptr(),
// (T*)beta.data_ptr(),
// epsilon,
// bsz,
// input_cont.size(2),
// preLayerNorm,
// Context::Instance().GetCurrentStream());
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& weight_out,
const float epsilon,
bool preLayerNorm,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
}
void residual_add_bias(at::Tensor& output,
at::Tensor& input,
at::Tensor& attention_output,
at::Tensor& output_b,
at::Tensor& attention_b,
int mp_size,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
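// Dispatch on dtype, then on layout: sequential (mlp_after_attn) blocks use
// launch_bias_residual, while GPT-J-style parallel attention/MLP blocks use
// launch_gptj_residual_add.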
// cudaStreamWaitEvent(
// Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
if (input.scalar_type() == at::kFloat)
if (mlp_after_attn)
launch_bias_residual((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<float>((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
else if (mlp_after_attn)
launch_bias_residual((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
at::Tensor& key_layer,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half,
bool rotate_every_two)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
unsigned bsz = mixed_query.size(0);
unsigned head_size = mixed_query.size(2) / num_heads;
unsigned seq_len = mixed_query.size(1);
if (mixed_query.scalar_type() == at::kFloat)
launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
(float*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
return {query_cont, key_cont};
}
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
(float*)output.data_ptr(),
M,
N,
at::cuda::getCurrentCUDAStream());
} else {
launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
(__half*)coef.data_ptr(),
(__half*)output.data_ptr(),
M,
N,
at::cuda::getCurrentCUDAStream());
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp32 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
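// A minimal usage sketch (assumption, not part of this file): if the extension is
// built with torch.utils.cpp_extension.load under a hypothetical name, the ops
// bound above can be called directly from Python, e.g.
//
//   from torch.utils.cpp_extension import load
//   ops = load(name="ds_inference_ops", sources=[...])  # sources/flags are build-specific
//   probs = ops.softmax_fp16(scores, mask, True, False, False, 1, False)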
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <array>
#include <vector>
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
#define MAX_OUT_TOKES 10
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
bool async_op)
{
auto attn_scores_c = attn_scores.contiguous();
int bsz = attn_scores_c.size(0);
int seq_len = attn_scores_c.size(1);
int len = attn_scores_c.sizes().size();
if (len > 3) seq_len = attn_scores_c.size(2);
int soft_len = attn_scores_c.size(2);
if (len > 3) soft_len = attn_scores_c.size(3);
int heads = 1;
if (len > 3) heads = attn_scores_c.size(1);
launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
1.0,
Context::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t max_seq_len,
size_t batch_size,
size_t head_size = 128)
{
size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
auto options = at::TensorOptions()
.dtype(Q.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
if (!workspace) {
allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_transpose,
m,
n,
k,
&alpha,
&gemm_beta,
(T*)W.data_ptr(),
(T*)Q.data_ptr(),
(T*)O.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return O;
}
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
at::Tensor& query_cont,
at::Tensor& attn_mask,
at::Tensor& prev_value_cont,
at::Tensor& output,
int& bsz,
int& seq_len,
int& soft_len,
int& heads,
float& norm_factor,
bool triangular,
bool recompute,
bool local_attention,
int window_size)
{
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
float alpha = norm_factor;
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
&alpha,
&gemm_beta,
(T*)prev_key_cont.data_ptr(),
(T*)query_cont.data_ptr(),
(T*)attn_score.data_ptr(),
rocblas_operation_none,
rocblas_operation_none,
soft_len * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
attn_score = ds_softmax<T>(
attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
&alpha,
&gemm_beta,
(T*)prev_value_cont.data_ptr(),
(T*)attn_score.data_ptr(),
(T*)output.data_ptr(),
rocblas_operation_none,
rocblas_operation_none,
soft_len * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
at::Tensor& prev_key,
at::Tensor& new_key,
at::Tensor& attn_mask,
at::Tensor& prev_value,
at::Tensor& new_value,
int heads,
float norm_factor,
bool merging,
bool triangular,
bool local_attention,
int window_size,
bool no_masking)
{
auto query_cont = query.contiguous();
auto prev_key_cont = prev_key.contiguous();
auto prev_value_cont = prev_value.contiguous();
int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int bsz = query_cont.size(0);
int seq_len = query_cont.size(1);
int soft_len = prev_value.size(1);
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output =
at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
attention_unfused<T>(prev_key_cont,
query_cont,
attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont,
output,
bsz,
seq_len,
soft_len,
heads,
norm_factor,
(triangular && (new_size == 0)),
(new_size == 0),
local_attention,
window_size);
return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);
launch_bias_gelu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto residual_cont = residual.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
launch_layer_norm((T*)inp_norm.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)gamma.data_ptr(),
(T*)betta.data_ptr(),
epsilon,
bsz,
input_cont.size(2),
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
// hipEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
int bsz = input.size(0) * input.size(1);
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm =
qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
return {output, inp_norm};
}
template <typename T>
void quantized_gemm(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int merge_count)
{
int bsz = input.size(0) * input.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(1),
weight.size(0),
groups,
merge_count,
Context::Instance().GetCurrentStream());
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool add_bias)
{
int bsz = input.size(0) * input.size(1);
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& q_scale,
int groups)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int bsz = input_cont.size(0) * input_cont.size(1);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
int groups,
int merge_count)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
return output;
}
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
launch_residual_layer_norm((T*)inp_norm.data_ptr(),
(T*)nullptr,
(T*)input.data_ptr(),
(T*)residual.data_ptr(),
(T*)input_bias.data_ptr(),
(T*)gamma.data_ptr(),
(T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
preLayerNorm,
mlp_after_attn,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
return output;
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
// computing the blocking across K dimension
// launch_residual_layer_norm((T*)inp_norm.data_ptr(),
// (T*)residual_add.data_ptr(),
// (T*)input_cont.data_ptr(),
// (T*)residual.data_ptr(),
// (T*)input_bias.data_ptr(),
// (T*)gamma.data_ptr(),
// (T*)beta.data_ptr(),
// epsilon,
// bsz,
// input_cont.size(2),
// preLayerNorm,
// Context::Instance().GetCurrentStream());
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& weight_out,
const float epsilon,
bool preLayerNorm,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// hipEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
}
void residual_add_bias(at::Tensor& output,
at::Tensor& input,
at::Tensor& attention_output,
at::Tensor& output_b,
at::Tensor& attention_b,
int mp_size,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
// hipStreamWaitEvent(
// Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
if (input.scalar_type() == at::kFloat)
if (mlp_after_attn)
launch_bias_residual((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<float>((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
else if (mlp_after_attn)
launch_bias_residual((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
at::Tensor& key_layer,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half,
bool rotate_every_two)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
unsigned bsz = mixed_query.size(0);
unsigned head_size = mixed_query.size(2) / num_heads;
unsigned seq_len = mixed_query.size(1);
if (mixed_query.scalar_type() == at::kFloat)
launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
(float*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
return {query_cont, key_cont};
}
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
(float*)output.data_ptr(),
M,
N,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
} else {
launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
(__half*)coef.data_ptr(),
(__half*)output.data_ptr(),
M,
N,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp32 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
cudaError_t err = cudaGetLastError();
if (err == cudaSuccess) return;
std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
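// attn_softmax_v2: numerically stable row-wise softmax over attention scores,
// one row per (batch, head, query) position. `triangular` applies a causal
// mask, `local_attention` restricts each query to a trailing window of
// `window_size` keys, and when `recompute` is set the additive attention mask
// is re-applied before the max/sum reductions. `reduceWidth` lanes cooperate
// on each row, spilling partial results through shared memory whenever the
// row spans more than one warp.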
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = low_data[i].x / sum;
vals[data_id + 1] = low_data[i].y / sum;
vals[data_id + 2] = high_data[i].x / sum;
vals[data_id + 3] = high_data[i].y / sum;
} else {
vals[data_id] = low_data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
}
}
}
}
#endif
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream)
{
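    // Launch geometry: every thread consumes four score elements per
    // iteration and `reduce_width` lanes cooperate on one row, so
    // iterations = ceil(sequence_length / (4 * reduce_width)). A block of
    // ATTN_THREADS threads then covers ATTN_THREADS / reduce_width rows, and
    // grid_dim is sized so that all total_count = batch * heads * num_seq
    // rows are processed.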
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream);
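// A minimal host-side reference for the per-row math of attn_softmax_v2
// (illustrative sketch only; masking and windowing omitted, <cmath> assumed).
// It mirrors the kernels above: subtract the row max, exponentiate, and
// normalize with the same 1e-6 denominator guard.
inline void softmax_row_reference(float* row, int n)
{
    float max_val = row[0];
    for (int i = 1; i < n; i++) max_val = (row[i] > max_val ? row[i] : max_val);
    float sum = 0.f;
    for (int i = 0; i < n; i++) {
        row[i] = expf(row[i] - max_val);
        sum += row[i];
    }
    sum += 1e-6f;  // same denominator guard as the device kernels
    for (int i = 0; i < n; i++) row[i] /= sum;
}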
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
hipError_t err = hipGetLastError();
if (err == hipSuccess) return;
std::cerr << hipGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = low_data[i].x / sum;
vals[data_id + 1] = low_data[i].y / sum;
vals[data_id + 2] = high_data[i].x / sum;
vals[data_id + 3] = high_data[i].y / sum;
} else {
vals[data_id] = low_data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
}
}
}
}
#endif
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
hipLaunchKernelGGL(( attn_softmax_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
return std::max(
std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow an empty grid
1);
}
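// Illustrative pairing of the helpers above (a sketch, not part of this
// header): a grid-stride elementwise kernel launched with DS_GET_BLOCKS.
//
//   __global__ void scale_kernel(float* x, float s, int n)
//   {
//       CUDA_1D_KERNEL_LOOP(i, n) { x[i] *= s; }
//   }
//   // launch:
//   // scale_kernel<<<DS_GET_BLOCKS(n), DS_CUDA_NUM_THREADS>>>(x, s, n);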
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, 123);
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Fail to create cublas handle.");
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comp_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comm_event, (cudaEventDisableTiming | cudaEventBlockingSync));
}
virtual ~Context()
{
cublasDestroy(_cublasHandle);
cudaFree(_workspace);
cudaEventDestroy(_comp1_event);
cudaEventDestroy(_comp2_event);
cudaEventDestroy(_comp_event);
cudaEventDestroy(_comm_event);
}
static Context& Instance()
{
static Context _ctx;
return _ctx;
}
    void GenWorkSpace(size_t size)
    {
        // Record the size only when (re)allocating, so a smaller request can
        // reuse the existing buffer without forgetting its true capacity.
        if (!_workspace) {
            assert(_workspace == nullptr);
            cudaMalloc(&_workspace, size);
            _workSpaceSize = size;
        } else if (_workSpaceSize < size) {
            cudaFree(_workspace);
            cudaMalloc(&_workspace, size);
            _workSpaceSize = size;
        }
    }
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
size_t get_workspace_size() const { return _workSpaceSize; }
void* GetWorkSpace() { return _workspace; }
inline unsigned new_token(unsigned layer_id)
{
if (layer_id == 0) _token_length++;
return _token_length;
}
inline void reset_tokens(unsigned initial_tokens = 0)
{
_num_tokens = initial_tokens;
} //_token_length = 0; }
inline unsigned current_tokens() const { return _num_tokens; }
inline void advance_tokens() { _num_tokens++; }
curandGenerator_t& GetRandGenerator() { return _gen; }
cudaStream_t GetCommStream(bool async_op = false)
{
if (!_comm_stream)
_comm_stream = async_op ? at::cuda::getStreamFromPool(true)
: at::cuda::getCurrentCUDAStream();
return _comm_stream;
}
cudaStream_t GetCurrentStream(bool other_stream = false)
{
// get current pytorch stream.
if (other_stream) {
if (!_stream) _stream = at::cuda::getStreamFromPool(true);
return _stream;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}
cublasHandle_t GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}
void SetSeed(uint64_t new_seed) { _seed = new_seed; }
const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
inline void SynchComp()
{
cudaEventRecord(_comp_event, _comp_stream);
cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
}
inline void SynchComm()
{
cudaEventRecord(_comm_event, _comm_stream);
cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
}
private:
curandGenerator_t _gen;
cublasHandle_t _cublasHandle;
cudaEvent_t _comp_event;
cudaEvent_t _comm_event;
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
cudaEvent_t _comp1_event;
cudaEvent_t _comp2_event;
cudaStream_t _stream;
unsigned _token_length;
unsigned _num_tokens;
std::vector<std::array<int, 3>> _gemm_algos;
cudaStream_t _comp_stream;
cudaStream_t _comm_stream;
std::unordered_map<int, int> _world_sizes;
};
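// Typical usage sketch of the singleton above (illustrative; `my_kernel` is a
// placeholder, not a symbol defined in this codebase):
//
//   cudaStream_t stream = Context::Instance().GetCurrentStream();
//   Context::Instance().GenWorkSpace(bytes);
//   float* ws = static_cast<float*>(Context::Instance().GetWorkSpace());
//   my_kernel<<<grid, block, 0, stream>>>(ws, ...);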
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand/hiprand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
hipError_t error_code = callstr; \
if (error_code != hipSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
return std::max(
std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow an empty grid
1);
}
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
{
hiprandCreateGenerator(&_gen, HIPRAND_RNG_PSEUDO_DEFAULT);
hiprandSetPseudoRandomGeneratorSeed(_gen, 123);
if (rocblas_create_handle(&_cublasHandle) != rocblas_status_success) {
auto message = std::string("Fail to create cublas handle.");
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
hipEventCreate(&_comp1_event, (hipEventDisableTiming | hipEventBlockingSync));
hipEventCreate(&_comp2_event, (hipEventDisableTiming | hipEventBlockingSync));
hipEventCreate(&_comp_event, (hipEventDisableTiming | hipEventBlockingSync));
hipEventCreate(&_comm_event, (hipEventDisableTiming | hipEventBlockingSync));
}
virtual ~Context()
{
rocblas_destroy_handle(_cublasHandle);
hipFree(_workspace);
hipEventDestroy(_comp1_event);
hipEventDestroy(_comp2_event);
hipEventDestroy(_comp_event);
hipEventDestroy(_comm_event);
}
static Context& Instance()
{
static Context _ctx;
return _ctx;
}
    void GenWorkSpace(size_t size)
    {
        // Record the size only when (re)allocating, so a smaller request can
        // reuse the existing buffer without forgetting its true capacity.
        if (!_workspace) {
            assert(_workspace == nullptr);
            hipMalloc(&_workspace, size);
            _workSpaceSize = size;
        } else if (_workSpaceSize < size) {
            hipFree(_workspace);
            hipMalloc(&_workspace, size);
            _workSpaceSize = size;
        }
    }
hipEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
size_t get_workspace_size() const { return _workSpaceSize; }
void* GetWorkSpace() { return _workspace; }
inline unsigned new_token(unsigned layer_id)
{
if (layer_id == 0) _token_length++;
return _token_length;
}
inline void reset_tokens(unsigned initial_tokens = 0)
{
_num_tokens = initial_tokens;
} //_token_length = 0; }
inline unsigned current_tokens() const { return _num_tokens; }
inline void advance_tokens() { _num_tokens++; }
hiprandGenerator_t& GetRandGenerator() { return _gen; }
hipStream_t GetCommStream(bool async_op = false)
{
if (!_comm_stream)
_comm_stream = async_op ? at::hip::getStreamFromPoolMasqueradingAsCUDA(true)
: at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
return _comm_stream;
}
hipStream_t GetCurrentStream(bool other_stream = false)
{
// get current pytorch stream.
if (other_stream) {
if (!_stream) _stream = at::hip::getStreamFromPoolMasqueradingAsCUDA(true);
return _stream;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
return stream;
}
rocblas_handle GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}
void SetSeed(uint64_t new_seed) { _seed = new_seed; }
const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
inline void SynchComp()
{
hipEventRecord(_comp_event, _comp_stream);
hipStreamWaitEvent(_comm_stream, _comp_event, 0);
}
inline void SynchComm()
{
hipEventRecord(_comm_event, _comm_stream);
hipStreamWaitEvent(_comp_stream, _comm_event, 0);
}
private:
hiprandGenerator_t _gen;
rocblas_handle _cublasHandle;
hipEvent_t _comp_event;
hipEvent_t _comm_event;
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
hipEvent_t _comp1_event;
hipEvent_t _comp2_event;
hipStream_t _stream;
unsigned _token_length;
unsigned _num_tokens;
std::vector<std::array<int, 3>> _gemm_algos;
hipStream_t _comp_stream;
hipStream_t _comm_stream;
std::unordered_map<int, int> _world_sizes;
};
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
CUDA_R_32F,
(transa == CUBLAS_OP_N) ? m : k,
(const void*)B,
CUDA_R_32F,
(transb == CUBLAS_OP_N) ? k : n,
(const void*)beta,
C,
CUDA_R_32F,
m,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
CUDA_R_16F,
(transa == CUBLAS_OP_N) ? m : k,
(const void*)B,
CUDA_R_16F,
(transb == CUBLAS_OP_N) ? k : n,
(const void*)beta,
(void*)C,
CUDA_R_16F,
m,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasOperation_t op_A,
cublasOperation_t op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
(op_A == CUBLAS_OP_N) ? m : k,
stride_A,
B,
CUDA_R_32F,
(op_B == CUBLAS_OP_N) ? k : n,
stride_B,
beta,
C,
CUDA_R_32F,
m,
stride_C,
batch,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasOperation_t op_A,
cublasOperation_t op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
(op_A == CUBLAS_OP_N) ? m : k,
stride_A,
B,
CUDA_R_16F,
(op_B == CUBLAS_OP_N) ? k : n,
stride_B,
beta,
C,
CUDA_R_16F,
m,
stride_C,
batch,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                (int)status);
return EXIT_FAILURE;
}
return 0;
}
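// Usage sketch (illustrative): cuBLAS assumes column-major storage, so a
// row-major product C[M x N] = A[M x K] * B[K x N] can be computed as the
// column-major C^T = B^T * A^T by swapping the operand order. A direct
// column-major call through the wrapper above would look like:
//
//   float alpha = 1.f, beta = 0.f;
//   cublas_gemm_ex(Context::Instance().GetCublasHandle(),
//                  CUBLAS_OP_N,
//                  CUBLAS_OP_N,
//                  m, n, k,
//                  &alpha, &beta,
//                  A, B, C,
//                  CUBLAS_GEMM_DEFAULT_TENSOR_OP);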
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <assert.h>
#include <rocblas.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                (int)status);
return EXIT_FAILURE;
}
return 0;
}
#pragma once
#ifdef __HIP_PLATFORM_HCC__
#define HALF_PRECISION_AVAILABLE 1
#include <hip/hip_cooperative_groups.h>
#else
#if __CUDA_ARCH__ >= 700
#define HALF_PRECISION_AVAILABLE 1
#endif
#include <cooperative_groups.h>
#endif
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
#define MAX_WARP_NUM 32
#define WARP_SIZE 32
#define SMs 80
#define MAX_REGISTERS 256
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream);
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
cudaStream_t stream);
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream);
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
cudaStream_t stream);
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream);
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream);
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
cudaStream_t stream);
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int head_size,
int mp_size,
cudaStream_t stream);
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream);
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#pragma once
#ifdef __HIP_PLATFORM_HCC__
#define HALF_PRECISION_AVAILABLE 1
#include <hip/hip_cooperative_groups.h>
#else
#if __CUDA_ARCH__ >= 700
#define HALF_PRECISION_AVAILABLE 1
#endif
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
#define MAX_WARP_NUM 32
#define WARP_SIZE 32
#define SMs 80
#define MAX_REGISTERS 256
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream);
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, hipStream_t stream);
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
hipStream_t stream);
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream);
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int head_size,
int mp_size,
hipStream_t stream);
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream);
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
#include "custom_cuda_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput of
__half2 instructions, and to avoid the conversion overhead (roughly 1/8 of the __half2 arithmetic cost).
For specific launch constraints, see the launch functions.
*/
#define NORM_REG (MAX_REGISTERS / 4)
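// Host-side reference for the normalization these kernels fuse (illustrative
// sketch only; note that, like the kernels, epsilon is added to the variance
// before the reciprocal square root):
inline void layer_norm_row_reference(float* out,
                                     const float* in,
                                     const float* gamma,
                                     const float* beta,
                                     float epsilon,
                                     int n)
{
    float mean = 0.f;
    for (int i = 0; i < n; i++) mean += in[i];
    mean /= n;
    float variance = 0.f;
    for (int i = 0; i < n; i++) variance += (in[i] - mean) * (in[i] - mean);
    variance = variance / n + epsilon;
    float inv_std = 1.f / sqrtf(variance);
    for (int i = 0; i < n; i++) out[i] = (in[i] - mean) * inv_std * gamma[i] + beta[i];
}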
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (threadIdx.x == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch, the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
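/*
 * Worked example (illustrative; assumes THREADS == 256, which is defined in the
 * shared header rather than in this file):
 *   float,  hidden_size = 8192: row_stride = 8192, threads = 256, iterations = 32
 *   __half, hidden_size = 8192: row_stride = 4096 (two values per __half2),
 *                               threads = 128, iterations = 32
 * The per-thread register array vals_arr[NORM_REG] bounds the largest
 * iterations count a block can absorb.
 */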
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
    // The kernels below constrain the launch configuration, so enumerate the supported
    // hidden_dim ranges explicitly.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
    // The kernels below constrain the launch configuration, so enumerate the supported
    // hidden_dim ranges explicitly.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
 * Compute gradients using either X_hat or
 * the normalized input (invertible).
 * Combines the transpose with the gradient computation.
 */
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
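/*
 * Reduction computed above, in math form (sketch):
 *   betta_grad[j] = sum_r out_grad[r][j]
 *   gamma_grad[j] = sum_r x_hat[r][j] * out_grad[r][j]
 * In the invertible path, x_hat is reconstructed from the layer-norm output y
 * as x_hat = (y - betta) / gamma rather than re-read from memory. Each block
 * owns TILE_DIM columns; the padded shared buffers transpose the per-thread
 * partials so that one warp can finish each column's sum with shuffles.
 */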
/* Normalize Gamma & Betta gradients
 * Compute gradients using the input to
 * the normalization.
 * Combines the transpose with the gradient computation.
 */
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
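/*
 * This variant recomputes x_hat from the raw input instead:
 *   x_hat[r][j] = (X[r][j] - means[r]) * rsqrt(vars[r])
 * No separate epsilon term is needed because the forward kernels in this file
 * store vars[row] = variance + epsilon.
 */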
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input.
 * This type of backward is invertible!
 * We do the backward using X_hat = (X - u) / sqrt(variance), i.e. the output
 * of the normalization.
 */
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
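/*
 * Derivation sketch for the input gradient above. With g = gamma * out_grad and
 * x_hat = (x - u) / sqrt(var), the two block reductions implement the standard
 * layer-norm backward
 *   inp_grad = (g - mean_row(g) - x_hat * mean_row(g * x_hat)) / sqrt(var)
 * (using the identity mean_row(x_hat) = 0). The first reduction accumulates
 * x_hat * g * sqrt(var); the later division by var cancels the extra sqrt(var),
 * leaving the mean_row(g * x_hat) term. The second reduction supplies the
 * mean_row subtraction.
 */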
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
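/*
 * Illustrative host-side call of the launcher above (a sketch only; the
 * function name is hypothetical, and allocation and error checking are
 * omitted). Two streams let Backward1 (gamma/betta grads) and Backward2
 * (input grads) run concurrently, since they write disjoint outputs.
 */
void example_layerNorm_backward(const float* out_grad, const float* vals_hat,
                                const float* vars, const float* gamma,
                                const float* betta, float* gamma_grad,
                                float* betta_grad, float* inp_grad,
                                int batch, int hidden_dim)
{
    cudaStream_t streams[2];
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);
    launch_layerNorm_backward<float>(out_grad, vals_hat, vars, gamma, gamma_grad,
                                     betta_grad, inp_grad, batch, hidden_dim,
                                     streams, /*invertible=*/true, betta);
    cudaStreamSynchronize(streams[0]);
    cudaStreamSynchronize(streams[1]);
    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
}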
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input.
 * This type of backward is not invertible!
 * We do the backward using the input (X).
 */
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 xu[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
__half mean_h = means[row];
__half2 mean_reg = __halves2half2(mean_h, mean_h);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
xu[iterations] = (vals_hat_h[high_index] - mean_reg);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
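/*
 * The _fused_add variants above compute the same input gradient as
 * LayerNormBackward2 and add a second upstream gradient in the same pass:
 *   inp_grad = layer_norm_input_grad(out_grad1) + out_grad2
 * Fusing the elementwise add saves one full read and write of the activation
 * tensor compared to launching a separate addition kernel afterwards.
 */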
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input.
 * This type of backward is not invertible!
 * We do the backward using the input (X).
 */
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32, in order to utilize the 2x throughput of
__half2 instructions and to avoid the conversion overhead (roughly 1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
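/*
 * Illustrative sketch of the two-wide FP16 math referenced above (not called
 * anywhere in this file): a __half2 packs two FP16 lanes, so each instruction
 * below normalizes two elements at once.
 */
#ifdef HALF_PRECISION_AVAILABLE
__device__ __forceinline__ __half2 normalize_pair_sketch(__half2 x_minus_mean,
                                                         __half2 variance,
                                                         __half2 gamma,
                                                         __half2 beta)
{
    // h2rsqrt computes 1/sqrt(v) on both lanes; the multiply and add are
    // likewise two-wide vector operations.
    return x_minus_mean * h2rsqrt(variance) * gamma + beta;
}
#endif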
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
    if (training && threadIdx.x == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
    if (training && threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
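/*
 * Illustrative host-side call of the HIP launcher above (a sketch only; the
 * function name and the 1e-12f epsilon are hypothetical, and allocation and
 * error checking are omitted):
 */
void example_bias_residual_layer_norm(__half* vals, const __half* residual,
                                      const __half* gamma, const __half* beta,
                                      __half* vars, __half* means,
                                      int batch_size, int hidden_dim)
{
    hipStream_t stream;
    hipStreamCreate(&stream);
    launch_bias_residual_layer_norm<__half>(vals, residual, gamma, beta,
                                            /*epsilon=*/1e-12f, batch_size,
                                            hidden_dim, stream,
                                            /*preLayerNorm=*/true,
                                            /*training=*/true, vars, means);
    hipStreamSynchronize(stream);
    hipStreamDestroy(stream);
}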
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
    if (training && threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
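/* Illustrative sketch (not part of the original file): one way to derive the
 * block size implied by the restrictions above. It mirrors the doubling
 * scheme of the float specialization below; for half, the launcher passes
 * hidden_dim / 2 as row_stride, so its thresholds are halved.
 */
static inline int pick_layer_norm_block_size(int hidden_dim)
{
    int threads = THREADS;
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
    return threads;  // iterations == hidden_dim / threads, plus a tail iteration
}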
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
// These kernels can only be launched with certain block sizes, so enumerate the supported hidden_dim ranges explicitly.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
// These kernels can only be launched with certain block sizes, so enumerate the supported hidden_dim ranges explicitly.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
* Compute gradients using either X_hat or
* the normalized input (invertible).
* Combines the transpose with the gradient computation.
*/
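/* Sketch of the math implemented below (not part of the original file):
 * with dy = out_grad and x_hat the normalized activations,
 *   betta_grad[j] = sum_r dy[r][j]
 *   gamma_grad[j] = sum_r x_hat[r][j] * dy[r][j]
 * In the invertible path, x_hat is recovered from the layer output y via
 *   x_hat = (y - betta) / gamma.
 * The TILE_DIM x (TILE_DIM + 1) shared buffers transpose the per-thread
 * partial sums so each warp can reduce one output column with shfl_down.
 */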
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
* Compute gradients using the input to
* the normalization.
* Combines the transpose with the gradient computation.
*/
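/* Sketch (not part of the original file): same column reductions as the
 * invertible variant above, but x_hat is rebuilt from the raw input:
 *   x_hat[r][j] = (X[r][j] - means[r]) * rsqrt(vars[r])
 */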
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input.
* This type of backward is invertible!
* We do the backward using X_hat = (X - u) / sqrt(variance), i.e. the output of the normalization.
*/
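/* Sketch of the input gradient computed below (not part of the original
 * file): with g = gamma * out_grad, mean(.) the per-row average over
 * row_stride elements, and var including epsilon,
 *   inp_grad = (g - x_hat * mean(g * x_hat) - mean(g)) / sqrt(var)
 * up to the usual mean(x_hat) = 0 simplification; the two block-wide
 * reductions below compute mean(g * x_hat) and the final mean term.
 */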
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
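/* Sketch (not part of the original file): identical input-gradient formula
 * to the invertible version, except x_hat is formed from the raw input,
 *   x_hat = (X - means[row]) * rsqrt(vars[row]),
 * via xu = X - means[row] in the reductions below.
 */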
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 xu[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
__half mean_h = means[row];
__half2 mean_reg = __halves2half2(mean_h, mean_h);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
xu[iterations] = (vals_hat_h[high_index] - mean_reg);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
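/* Sketch (not part of the original file): the fused variant computes the
 * same input gradient and then folds the residual branch into the epilogue,
 *   inp_grad = layer_norm_input_grad(out_grad1) + out_grad2,
 * saving a separate elementwise add kernel.
 */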
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"
namespace cg = cooperative_groups;
dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
// Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// The batch size is typically relatively small, while the sequence length could potentially be
// arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
unsigned x = heads * sequence_length / block_compute_size;
unsigned y = batch_size;
return {x, y};
}
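/* Worked example (illustrative, not part of the original file): for
 * batch_size = 8, heads = 16, sequence_length = 128, threads = 128:
 * seq_length4 = 32, block_compute_size = pow(2, floor(log2(128 / 32))) = 4,
 * so the returned grid is {x = 16 * 128 / 4 = 512, y = 8}.
 */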
// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = std::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> WARP_SIZE_BITS;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = std::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> WARP_SIZE_BITS;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
cudaStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
attn_softmax<2, (threads / 2), 2>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
attn_softmax<4, (threads / 4), 4>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
attn_softmax<8, (threads / 8), 8>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
attn_softmax<16, (threads / 16), 16>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
attn_softmax<32, (threads / 32), 32>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
attn_softmax<32, (threads / 64), 64>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupported sequence length! Check the restrictions on MAX_THREADS and "
"MAX_THREAD_ITERATIONS!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
cudaStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
attn_softmax<2, (threads / 2), 2>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
attn_softmax<4, (threads / 4), 4>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
attn_softmax<8, (threads / 8), 8>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
attn_softmax<16, (threads / 16), 16>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
attn_softmax<32, (threads / 32), 32>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
attn_softmax<32, (threads / 64), 64>
<<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupported sequence length! Check the restrictions on MAX_THREADS and "
"MAX_THREAD_ITERATIONS!");
}
}
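/* Sketch of the softmax backward identity implemented by the two kernels
 * below (not part of the original file): for y = softmax(x) and upstream
 * gradient dy,
 *   dx_i = y_i * (dy_i - sum_j dy_j * y_j)
 * Each kernel first reduces the row-wise dot product sum_j dy_j * y_j, then
 * applies the elementwise update in place on out_grad.
 */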
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
            grad_reg += ((float)val_reg[i] *
                         (float)soft_reg[i]); // accumulate in float: performing this
                                              // multiplication in half can cost ~2%
                                              // accuracy
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
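// v2 assigns one warp per softmax row instead of spreading a row across a block:
// blockDim is (WARP_SIZE, warps_per_block), threadIdx.y selects the row, and each
// thread strides through the row in steps of WARP_SIZE, holding
// ITERATIONS = ceil(softmax_length / WARP_SIZE) elements in registers. The row sum
// then needs only a single warp shuffle reduction, with no shared memory or block
// synchronization.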
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
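// A minimal CPU reference for the backward kernels above (a validation sketch
// added here, not part of the original kernels; the name softmax_backward_ref is
// illustrative). `grad` holds dy on entry and dx on return, matching the in-place
// contract of softmax_backward_kernel_v2.
inline void softmax_backward_ref(float* grad,         // in: dy, out: dx (row-major)
                                 const float* output, // y = softmax(x), same shape
                                 int rows,
                                 int softmax_length)
{
    for (int r = 0; r < rows; ++r) {
        float sum = 0.0f;
        for (int i = 0; i < softmax_length; ++i)
            sum += grad[r * softmax_length + i] * output[r * softmax_length + i];
        for (int i = 0; i < softmax_length; ++i)
            grad[r * softmax_length + i] =
                output[r * softmax_length + i] * (grad[r * softmax_length + i] - sum);
    }
}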
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
cudaStream_t stream)
{
    const int warps_per_block = 4;
    // One warp handles one softmax row; batch_size * heads * seq_length is assumed
    // to be a multiple of warps_per_block so the grid covers every row exactly.
    dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
    dim3 block_dim(WARP_SIZE, warps_per_block);
if (seq_length <= 32)
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else
        throw std::runtime_error(
            std::string("Unsupported sequence length in softmax backward, seq_length: ") +
            std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
cudaStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
cudaStream_t stream);
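// Hedged usage sketch (illustrative, not part of the original file): driving the
// backward launcher on device buffers. d_out_grad holds dy on entry and dx on
// return; d_soft_inp is the softmax output saved from the forward pass. Note the
// divisibility assumption on batch_size * heads * seq_length documented above.
inline void attn_softmax_backward_example(float* d_out_grad,
                                          const float* d_soft_inp,
                                          int batch_size,
                                          int heads,
                                          int seq_length,
                                          cudaStream_t stream)
{
    launch_attn_softmax_backward_v2<float>(
        d_out_grad, d_soft_inp, batch_size, heads, seq_length, stream);
}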