"vscode:/vscode.git/clone" did not exist on "c2abba11d1094e76d28067601726a30118bba518"
Commit 50f07abd authored by zhanghj2's avatar zhanghj2
Browse files

处理attn_sink中inf的情况

parent 75890221
...@@ -167,13 +167,13 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -167,13 +167,13 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
} }
[[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length; [[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length;
[[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx; [[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx;
int token_index = indices_base[(lane_idx % 16) + warp_idx * 16];
if constexpr (MODEL_TYPE == ModelType::MODEL1) { if constexpr (MODEL_TYPE == ModelType::MODEL1) {
if (rel_block_idx*TOPK_BLOCK_SIZE >= topk_length) if (rel_block_idx*TOPK_BLOCK_SIZE + (lane_idx % 16) + warp_idx * 16 >= topk_length)
{ {
token_index = -1;
} }
} }
int token_index = indices_base[(lane_idx % 16) + warp_idx * 16];
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size; int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size;
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
const index_t offset_k = block_index * k_block_stride; const index_t offset_k = block_index * k_block_stride;
...@@ -379,7 +379,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -379,7 +379,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
tSrK(j, 0, 7) = rst; tSrK(j, 0, 7) = rst;
} }
constexpr static int k_idx = 7; constexpr static int k_idx = 7;
// if (block0() && threadIdx.x >= 192) // if (block0())
// { // {
// printf(" %.4f %.4f %.4f %.4f \n", // printf(" %.4f %.4f %.4f %.4f \n",
// float(tSrK(0, 0, 7)), // float(tSrK(0, 0, 7)),
...@@ -499,10 +499,13 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -499,10 +499,13 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// int idx = indices_base[int(get<1>(tScS(i)))] ; // int idx = indices_base[int(get<1>(tScS(i)))] ;
if (not is_index_valid(i)) acc_s(i) = -INFINITY; if (not is_index_valid(i)) acc_s(i) = -INFINITY;
} }
block_idx == 0 block_idx == args.start_block_idx
? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2) ? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2)
: softmax.template softmax_rescale_o_prefill</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2); : softmax.template softmax_rescale_o_prefill</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2);
// if (head_block_idx == 0 && batch_idx == 0)
// {
// printf("tidx = %d, %.2f %.2f %.2f %.2f \n", tidx, float(acc_s(0)), float(acc_s(1)), float(acc_s(2)), float(acc_s(3)));
// }
Tensor rP = flash::convert_type<Element>(acc_s); Tensor rP = flash::convert_type<Element>(acc_s);
Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
...@@ -543,7 +546,10 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -543,7 +546,10 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
process_one_block(block_idx, IsExtraBlock{}); process_one_block(block_idx, IsExtraBlock{});
} }
} }
// if (head_block_idx == 0 && threadIdx.x < 64 && batch_idx == 0)
// {
// printf(" %.4f %.4f \n", acc_o(0), acc_o(1));
// }
if (args.is_no_split) { if (args.is_no_split) {
int start_head_idx = head_block_idx*BLOCK_M; int start_head_idx = head_block_idx*BLOCK_M;
Tensor lse = softmax.template normalize_softmax_lse<false>(acc_o, sRow_sum_reduce_buffer, params.sm_scale); Tensor lse = softmax.template normalize_softmax_lse<false>(acc_o, sRow_sum_reduce_buffer, params.sm_scale);
...@@ -553,13 +559,27 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -553,13 +559,27 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
make_stride(params.stride_o_h_q, _1{})); make_stride(params.stride_o_h_q, _1{}));
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
float rAttn_sink = __ldg((float*)params.attn_sink + start_head_idx + lane_idx % 16); float rAttn_sink = __ldg((float*)params.attn_sink + start_head_idx + lane_idx % 16);
float lse_exp2 = __builtin_amdgcn_exp2f(lse[lane_idx % 16] * CUDART_L2E_F); if (flash::is_positive_infinity(rAttn_sink))
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
float o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
for (int i = 0; i < size(acc_o); i++)
{ {
acc_o(i) *= o_scale; for (int i = 0; i < size(acc_o); i++)
{
acc_o(i) = 0.0f;
}
} }
else
{
if (!flash::is_positive_infinity(lse(0)))
{
float lse_exp2 = __builtin_amdgcn_exp2f(lse[0] * CUDART_L2E_F);
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
float o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
for (int i = 0; i < size(acc_o); i++)
{
acc_o(i) *= o_scale;
}
}
}
} }
float* gSoftmaxLse = (float*)params.lse + batch_idx * params.stride_lse_b + start_head_idx + s_q_idx * params.stride_lse_s_q; // (BLOCK_M) : (1) float* gSoftmaxLse = (float*)params.lse + batch_idx * params.stride_lse_b + start_head_idx + s_q_idx * params.stride_lse_s_q; // (BLOCK_M) : (1)
......
...@@ -90,9 +90,17 @@ flash_fwd_mla_combine_kernel(const CombineParams params) { ...@@ -90,9 +90,17 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
float attn_sink = __ldg(params.attn_sink + q_head_idx); float attn_sink = __ldg(params.attn_sink + q_head_idx);
float Attn_sink_exp2 = __builtin_amdgcn_exp2f(attn_sink * CUDART_L2E_F); if (flash::is_positive_infinity(attn_sink))
float lse_exp2 = __builtin_amdgcn_exp2f(global_lse * CUDART_L2E_F); {
o_scale = lse_exp2 / (lse_exp2 + Attn_sink_exp2); o_scale = 0.0f;
}
else
{
float Attn_sink_exp2 = __builtin_amdgcn_exp2f(attn_sink * CUDART_L2E_F);
float lse_exp2 = __builtin_amdgcn_exp2f(global_lse * CUDART_L2E_F);
o_scale = lse_exp2 / (lse_exp2 + Attn_sink_exp2);
}
// if (global_lse != INFINITY) { // if (global_lse != INFINITY) {
// // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf) // // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment