"vscode:/vscode.git/clone" did not exist on "add95438dfea684bac387fe1cf48e8bdd0c482d8"
Commit 200f01d5 authored by zhanghj2's avatar zhanghj2
Browse files

支持attn_sink

parent 9b54b03c
......@@ -17,7 +17,7 @@ using namespace cute;
namespace smxx::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
__global__ void __launch_bounds__(NUM_THREADS, 1)
flash_fwd_mla_combine_kernel(const CombineParams params) {
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
......@@ -54,20 +54,8 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
);
__shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];
// Wait for the previous kernel (the MLA kernel) to finish
// cudaGridDependencySynchronize();
// Prefetch
static_assert(HEAD_DIM_V % (64*4) == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4);
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*256); // NOTE We don't use __ldg here since it is incompatible with PDL
}
// __syncthreads();
// Warp #i gathers LseAccum for seq #i
{
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 64);
......@@ -90,36 +78,50 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
float sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
sum_lse = sum_lse + __builtin_amdgcn_exp2f((local_lse[i] - max_lse) * CUDART_L2E_F);
CUTLASS_PRAGMA_UNROLL
for (int offset = 32; offset >= 1; offset /= 2)
sum_lse = sum_lse + __shfl_xor(sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : logf(sum_lse) + max_lse;
if (lane_idx == 0)
gLse(warp_idx) = global_lse / (float)M_LOG2E;
gLse(warp_idx) = global_lse;
float o_scale = 0.0f;
if (params.attn_sink != nullptr) {
int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
float attn_sink = __ldg(params.attn_sink + q_head_idx);
if (global_lse != INFINITY) {
// If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// If attn_sink is -inf, this has no effect on global_lse
global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
} else {
// We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
}
float Attn_sink_exp2 = __builtin_amdgcn_exp2f(attn_sink * CUDART_L2E_F);
float lse_exp2 = __builtin_amdgcn_exp2f(global_lse * CUDART_L2E_F);
o_scale = lse_exp2 / (lse_exp2 + Attn_sink_exp2);
// if (global_lse != INFINITY) {
// // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// // If attn_sink is -inf, this has no effect on global_lse
// global_lse += log2f(1 + __builtin_amdgcn_exp2f(attn_sink*CUDART_L2E_F - global_lse));
// } else {
// // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
// global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
// }
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*64 + lane_idx;
smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
if (split_idx < my_num_splits) {
// printf("local_lse %.2f global_lse = %.2f \n", local_lse[i], global_lse);
smem_buf[warp_idx][split_idx] = __builtin_amdgcn_exp2f((local_lse[i] - global_lse) * CUDART_L2E_F) * o_scale;
}
}
}
__syncthreads();
static_assert(HEAD_DIM_V % (64*4) == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (64*4);
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*256); // NOTE We don't use __ldg here since it is incompatible with PDL
}
// Warp #i accumulates activation for seq #i
{
float4 result[ELEMS_PER_THREAD];
......@@ -130,6 +132,10 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
#pragma unroll 1
for (int split = 0; split < my_num_splits; ++split) {
float lse_scale = smem_buf[warp_idx][split];
// if (warp_idx == 2 && threadIdx.x == 128)
// {
// printf("threadIdx.x = %d %.3f %.3f lse_scale = %.2f \n",threadIdx.x, datas[0].x, datas[1].x, lse_scale);
// }
// if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
......@@ -143,7 +149,10 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
}
// }
}
// if (warp_idx == 2)
// {
// printf(" %.3f \n", result[0].x);
// }
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
......@@ -151,6 +160,12 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
float4 data = result[i];
ElementT data_converted[4];
// auto res = __builtin_hcu_cvt_pk_bf16_f32(0, data.x, 0, data.y, 0);
// data_converted[0].storage = res[0];
// data_converted[1].storage = res[1];
// res = __builtin_hcu_cvt_pk_bf16_f32(0, data.z, 0, data.w, 0);
// data_converted[2].storage = res[0];
// data_converted[3].storage = res[1];
data_converted[0] = (ElementT)(data.x);
data_converted[1] = (ElementT)(data.y);
data_converted[2] = (ElementT)(data.z);
......@@ -208,7 +223,7 @@ void run_flash_mla_combine_kernel(CombineParams &params) {
// 1
// };
combine_kernel<<<dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
dim3(NUM_THREADS, 1, 1),
NUM_THREADS,
smem_size,
params.stream>>>(params);
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment