Unverified Commit 7cfed5f0 authored by Frank Lee, committed by GitHub

[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
parent d7f8db8e
#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>
#include <cub/cub.cuh>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;
/**
@brief: softmax_kernel
Softmax forward kernel for
enc-self-attn, dec-self-attn, enc-dec-attn
@thread
gridDim.x = dynamic
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = block_dim, chosen so that block_dim * ele_per_thread >= to_len
@param
inp: [batch_size, nhead, from_len, to_len], softmax input.
attn_mask: [batch_size, to_len], padding tokens are -inf,
non-padding tokens are 0.
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn inference
*/
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
int to_len, bool mask_future) {
int batch_id = blockIdx.y;
int head_id = blockIdx.z;
const int nhead = gridDim.z;
const int token_per_reduce = 1;
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_dim, ele_per_thread,
cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
T mval[ele_per_thread];
if (attn_mask) {
attn_mask += batch_id * to_len;
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
}
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
token_id += gridDim.x * token_per_reduce) {
T inp_val[token_per_reduce][ele_per_thread];
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
REDUCE_FLOAT_INF_NEG);
}
/* step 1. compute max */
// thread local max
float val[token_per_reduce][ele_per_thread];
float l_max[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_max[i] = REDUCE_FLOAT_INF_NEG;
for (int j = 0; j < ele_per_thread; j++) {
if (attn_mask) {
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
} else {
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
val[i][j] = REDUCE_FLOAT_INF_NEG;
} else {
val[i][j] = (float)inp_val[i][j];
}
}
l_max[i] = fmaxf(l_max[i], val[i][j]);
}
}
// block reduce max
blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
// write shared
__shared__ float s_max[token_per_reduce];
if (threadIdx.x == 0) {
for (int i = 0; i < token_per_reduce; i++) {
s_max[i] = l_max[i];
}
}
__syncthreads();
/* step 2. compute sum */
// thread local sum
float l_sum[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_sum[i] = 0.f;
for (int j = 0; j < ele_per_thread; j++) {
val[i][j] = __expf(val[i][j] - s_max[i]);
l_sum[i] += val[i][j];
}
}
// block reduce sum
blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
// write shared
__shared__ float s_sum[token_per_reduce];
if (threadIdx.x == 0) {
for (int i = 0; i < token_per_reduce; i++) {
s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
}
}
__syncthreads();
/* step 3. compute final result */
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
for (int j = 0; j < ele_per_thread; j++) {
inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
}
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len);
}
} // blockIdx.x
}
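For reference, the masked softmax computed here (and by the warp-level variant ker_attn_softmax_lt32 below) corresponds to the following PyTorch sketch. The helper name is hypothetical, and unlike the kernels, which overwrite inp in place, it returns a new tensor:

import torch

def attn_softmax_reference(inp, attn_mask=None, mask_future=False):
    # inp: [batch_size, nhead, from_len, to_len]
    scores = inp.float()
    if attn_mask is not None:
        # attn_mask: [batch_size, to_len], -inf at padding positions, 0 elsewhere
        scores = scores + attn_mask.float()[:, None, None, :]
    elif mask_future:
        # causal mask: position i may only attend to positions j <= i
        from_len, to_len = scores.shape[-2], scores.shape[-1]
        causal = torch.tril(torch.ones(from_len, to_len, dtype=torch.bool, device=scores.device))
        scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1).to(inp.dtype)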
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
int to_len, bool mask_future) {
int batch_id = blockIdx.y;
int head_id = blockIdx.z;
const int nhead = gridDim.z;
const int token_per_reduce = 1;
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_dim, ele_per_thread,
cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
T mval[ele_per_thread];
if (attn_mask) {
attn_mask += batch_id * to_len;
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
}
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
token_id += gridDim.x * token_per_reduce) {
T inp_val[token_per_reduce][ele_per_thread];
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
REDUCE_FLOAT_INF_NEG);
}
/* step 1. compute max */
// thread local max
float val[token_per_reduce][ele_per_thread];
float l_max[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_max[i] = REDUCE_FLOAT_INF_NEG;
for (int j = 0; j < ele_per_thread; j++) {
if (attn_mask) {
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
} else {
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
val[i][j] = REDUCE_FLOAT_INF_NEG;
} else {
val[i][j] = (float)inp_val[i][j];
}
}
l_max[i] = fmaxf(l_max[i], val[i][j]);
}
}
// warp reduce max
warpReduce<ReduceType::kMax, token_per_reduce>(l_max);
/* step 2. compute sum */
// thread local sum
float l_sum[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_sum[i] = 0.f;
for (int j = 0; j < ele_per_thread; j++) {
val[i][j] = __expf(val[i][j] - l_max[i]);
l_sum[i] += val[i][j];
}
}
// warp reduce sum
warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);
/* step 3. compute final result */
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
for (int j = 0; j < ele_per_thread; j++) {
inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
}
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len);
}
} // blockIdx.x
}
/*
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn inference
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
int batch_size, int nhead, int from_len,
int to_len, bool mask_future,
cudaStream_t stream) {
dim3 grid_dim(1, batch_size, nhead);
if (to_len <= 32) {
ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 64) {
ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 128) {
grid_dim.x = 16;
ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 256) {
grid_dim.x = 32;
ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 512) {
grid_dim.x = 64;
ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else {
throw std::runtime_error(
"Sequence length greater than 512 is currently not supported");
}
}
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
int batch_size, int nhead, int from_len,
int to_len, bool mask_future,
cudaStream_t stream) {
dim3 grid_dim(1, batch_size, nhead);
if (to_len <= 32) {
ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 64) {
ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 128) {
grid_dim.x = 8;
ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 256) {
grid_dim.x = 16;
ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 512) {
grid_dim.x = 32;
ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else {
throw std::runtime_error(
"Sequence length greater than 512 is currently not supported");
}
}
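The two specializations above share the same bucketing: they pick the smallest (block_dim, ele_per_thread) pair whose product covers to_len, falling back from the warp-level kernel to the block-level one past 64 columns (only grid_dim.x differs between float and __half). A sketch of that selection logic, with a hypothetical helper name, for illustration only:

def softmax_launch_config(to_len):
    # (kernel, block_dim, ele_per_thread); block_dim * ele_per_thread must cover to_len
    for kernel, block_dim, ele_per_thread in [("lt32", 32, 1), ("lt32", 32, 2),
                                              ("full", 64, 2), ("full", 128, 2), ("full", 256, 2)]:
        if to_len <= block_dim * ele_per_thread:
            return kernel, block_dim, ele_per_thread
    raise RuntimeError("Sequence length greater than 512 is currently not supported")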
/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention.
@thread
gridDim.x = batch_size * nhead * seq_len / warps_per_block
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
@param
grad: [batch_size, nhead, seq_len, seq_len], gradient w.r.t. the softmax
output; overwritten in place with the gradient w.r.t. the softmax input.
inp: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
*/
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
inp += offset;
T grad_reg[ITERATIONS];
T inp_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
inp_reg[i] = inp[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)inp_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
}
}
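The kernel above applies the standard softmax Jacobian-vector product: given y = softmax(x) and upstream gradient dy, it computes dx = y * (dy - sum_j(dy_j * y_j)), with the reduction running over the softmax axis via warp shuffles. A PyTorch reference for checking, with a hypothetical name:

def attn_softmax_bw_reference(grad_out, softmax_out):
    # reduce dy * y over the softmax axis, then dx = y * (dy - dot)
    dot = (grad_out * softmax_out).sum(dim=-1, keepdim=True)
    return softmax_out * (grad_out - dot)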
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
int softmax_len, cudaStream_t stream) {
const int warps_per_block = 4;
// rows = batch_size * nhead * from_len
dim3 grid_dim(rows / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (softmax_len <= 32)
ker_attn_softmax_bw<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 64)
ker_attn_softmax_bw<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 128)
ker_attn_softmax_bw<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 256)
ker_attn_softmax_bw<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 384)
ker_attn_softmax_bw<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 512)
ker_attn_softmax_bw<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 768)
ker_attn_softmax_bw<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 1024)
ker_attn_softmax_bw<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 2048)
ker_attn_softmax_bw<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else
throw std::runtime_error(
std::string(
"Special sequence length found in softmax backward, seq_len: ") +
std::to_string(softmax_len));
}
template void launch_attn_softmax_bw<__half>(__half *out_grad,
const __half *soft_inp, int rows,
int softmax_len,
cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
const float *soft_inp, int rows,
int softmax_len,
cudaStream_t stream);
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape the input
during the backward pass of encoder self-attention
@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
input: [batch_size, seq_len, hidden_dim]
output: [batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
int head_dim);
template <>
__global__ void transform_0213<float>(float *output, const float *input,
int hidden_dim, int head_dim) {
int batch_id = blockIdx.x;
int token_id = blockIdx.y;
int seq_len = gridDim.y;
int nhead = hidden_dim / head_dim;
// [b, s, h]
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
// [b, nh, s, ad]
int trg_offset =
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vinput4;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinput4 = input4[src_offset + i];
int head_id = i / head_dim;
int dim_id = i % head_dim;
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
res4[trg_offset + cur_trg_offset] = vinput4;
}
}
template <>
__global__ void transform_0213<__half>(__half *output, const __half *input,
int hidden_dim, int head_dim) {
int batch_id = blockIdx.x;
int token_id = blockIdx.y;
int seq_len = gridDim.y;
int nhead = hidden_dim / head_dim;
// [b, s, h]
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
// [b, nh, s, ad]
int trg_offset =
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vinput4;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinput4 = input4[src_offset + i];
int head_id = i / head_dim;
int dim_id = i % head_dim;
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
res4[trg_offset + cur_trg_offset] = vinput4;
}
}
// [b, s, h] -> [b, nh, s, ad]
template <>
void launch_transform_0213<float>(float *output, const float *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, cudaStream_t stream) {
hidden_dim >>= 2;
int head_dim = hidden_dim / nhead;
dim3 grid_dim(batch_size, seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
}
template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, cudaStream_t stream) {
hidden_dim >>= 3;
int head_dim = hidden_dim / nhead;
dim3 grid_dim(batch_size, seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
transform_0213<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
}
/**
@brief: bias_add_transform_20314
Add bias to input, transform from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
const T *bias, int dim_3, int dim_4);
template <>
__global__ void bias_add_transform_20314<float>(float *output,
const float *input,
const float *bias, int dim_3,
int dim_4) {
int id0 = blockIdx.x;
int id1 = blockIdx.y;
int id2 = blockIdx.z;
int dim_0 = gridDim.x;
int dim_1 = gridDim.y;
int dim_2 = gridDim.z;
int dim_34 = dim_3 * dim_4;
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
int bias_offset = flat_2dim(id2, 0, dim_34);
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vqkv4;
float4 vbias4;
float4 vres4;
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
vqkv4 = qkv4[src_offset + i];
vbias4 = bias4[bias_offset + i];
vres4.x = vqkv4.x + vbias4.x;
vres4.y = vqkv4.y + vbias4.y;
vres4.z = vqkv4.z + vbias4.z;
vres4.w = vqkv4.w + vbias4.w;
int id3 = i / dim_4;
int id4 = i % dim_4;
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
res4[trg_offset + cur_trg_offset] = vres4;
}
}
template <>
__global__ void bias_add_transform_20314<__half>(__half *output,
const __half *input,
const __half *bias, int dim_3,
int dim_4) {
int id0 = blockIdx.x;
int id1 = blockIdx.y;
int id2 = blockIdx.z;
int dim_0 = gridDim.x;
int dim_1 = gridDim.y;
int dim_2 = gridDim.z;
int dim_34 = dim_3 * dim_4;
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
int bias_offset = flat_2dim(id2, 0, dim_34);
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vqkv4;
float4 vbias4;
float4 vres4;
__half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
__half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
__half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
vqkv4 = qkv4[src_offset + i];
vbias4 = bias4[bias_offset + i];
h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);
int id3 = i / dim_4;
int id4 = i % dim_4;
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
res4[trg_offset + cur_trg_offset] = vres4;
}
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
const float *bias, int dim_0,
int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream) {
dim_4 >>= 2;
dim3 grid_dim(dim_0, dim_1, dim_2);
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
bias_add_transform_20314<float>
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
}
template <>
void launch_bias_add_transform_20314<__half>(__half *output,
const __half *input,
const __half *bias, int dim_0,
int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream) {
dim_4 >>= 3;
dim3 grid_dim(dim_0, dim_1, dim_2);
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
bias_add_transform_20314<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
}
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
input: [trans_count, batch_size, nhead, seq_len, head_dim]
output: [batch_size, seq_len, trans_count, nhead, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
trans_count: 1 or 3, the number of matrices to be transformed
*/
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
int seq_len, int trans_count, int nhead,
int head_dim, int num_all) {
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset >= num_all) {
return;
}
int trans_id, batch_id, head_id, token_id, dim_id;
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
&batch_id, &head_id, &token_id, &dim_id);
// [b, s, tc, nh, ad]
int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
seq_len, trans_count, nhead, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
res4[trg_offset] = input4[offset];
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, int trans_count,
cudaStream_t stream) {
hidden_dim >>= 2;
int head_dim = hidden_dim / nhead;
int num_all = batch_size * seq_len * trans_count * hidden_dim;
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
num_all);
}
template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
int batch_size, int seq_len,
int hidden_dim, int nhead, int trans_count,
cudaStream_t stream) {
hidden_dim >>= 3;
int head_dim = hidden_dim / nhead;
int num_all = batch_size * seq_len * trans_count * hidden_dim;
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
num_all);
}
#include <torch/extension.h>
#include "linear.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
"Linear SiLU (INT8)");
}
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
);
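Once compiled as a torch extension, the binding can be driven from Python roughly as follows. The import name, tensor shapes, and scale values are assumptions for illustration, not the repo's actual test harness:

import torch
import linear_ext  # assumed import name of the compiled extension

x = torch.randint(-128, 128, (4, 64), dtype=torch.int8, device="cuda")   # INT8 activations
w = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")  # INT8 weights (assumed layout)
b = torch.zeros(32, dtype=torch.float32, device="cuda")                  # FP32 bias
y = linear_ext.linear_silu_a8_w8_bfp32_ofp32(x, w, b, 0.01, 1.0)         # alpha, beta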
../../extensions
from abc import ABC, abstractmethod
from typing import Callable
class BaseExtension(ABC):
    @property
    @abstractmethod
    def requires_build(self) -> bool:
        pass

    @abstractmethod
    def build(self) -> None:
        pass

    @abstractmethod
    def load(self) -> Callable:
        pass

    def fetch(self) -> Callable:
        if self.requires_build:
            self.build()
        return self.load()
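fetch() ties the three hooks together: build once if the extension reports it is needed, then load. A minimal concrete subclass, hypothetical and for illustration only:

class EchoExtension(BaseExtension):
    def __init__(self):
        self._built = False

    @property
    def requires_build(self) -> bool:
        return not self._built

    def build(self) -> None:
        self._built = True  # a real subclass would compile its kernel here

    def load(self) -> Callable:
        return lambda x: x

kernel = EchoExtension().fetch()  # builds once, then returns the loaded callable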
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension
__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
class ArmCPUAdamExtension(BaseExtension):
    def __init__(self) -> None:
        super().__init__()
        self.kernel_builder = ArmCPUAdamBuilder()
        self._requires_build = False

    @property
    def requires_build(self) -> bool:
        return self._requires_build

    def build(self):
        self.kernel_builder.build()
        self._requires_build = True

    def load(self):
        return self.kernel_builder.load()


class ArmCPUAdamBuilder(ExtensionBuilder):
    NAME = "arm_cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
    ext_type = "cpu"

    def __init__(self):
        super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # the four required builder hooks
    def sources_files(self):
        ret = [
            self.csrc_abs_path("cpu_adam_arm.cpp"),
        ]
        return ret

    def include_dirs(self):
        return [self.csrc_abs_path("includes")]

    def cxx_flags(self):
        extra_cxx_flags = [
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        return []
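With the builder wired in, obtaining the ARM CPU Adam kernel reduces to a two-line usage sketch:

ext = ArmCPUAdamExtension()
cpu_adam = ext.fetch()  # returns the loaded kernel module; the builder decides between prebuilt and JIT paths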
from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension
from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension
from .npu_sdpa_attn_extension import NpuSdpaAttnExtension
from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension
from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad
__all__ = [
"CudaFlashAttnExtension",
"CudaMemoryEfficentAttnExtension",
"NpuSdpaAttnExtension",
"NpuTriangleAttnExtension",
"HAS_FLASH_ATTN",
"HAS_MEM_EFF_ATTN",
"HAS_NPU_TRIANGLE_ATTENTION",
"Unpad",
"AttnMaskType",
"Repad",
"SeqLenInfo",
]