Commit 9b54b03c authored by zhanghj2's avatar zhanghj2
Browse files

支持attn_sink

parent 5813dcc1
......@@ -19,6 +19,7 @@
using namespace cute;
namespace sm90::decode::sparse_fp8 {
#define CUDART_L2E_F 1.442695041F
static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan
......@@ -403,6 +404,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_h_q, _1{}));
if (params.attn_sink != nullptr) {
float rAttn_sink = __ldg((float*)params.attn_sink + start_head_idx + lane_idx % 16);
float lse_exp2 = __builtin_amdgcn_exp2f(lse[lane_idx % 16] * CUDART_L2E_F);
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
float o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
for (int i = 0; i < size(acc_o); i++)
{
acc_o(i) *= o_scale;
}
}
float* gSoftmaxLse = (float*)params.lse + batch_idx * params.stride_lse_b + start_head_idx + s_q_idx * params.stride_lse_s_q; // (BLOCK_M) : (1)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment