Commit 75890221 authored by zhanghj2's avatar zhanghj2
Browse files

支持movel1 decode

parent 620f8769
...@@ -176,10 +176,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -176,10 +176,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
int token_index = indices_base[(lane_idx % 16) + warp_idx * 16]; int token_index = indices_base[(lane_idx % 16) + warp_idx * 16];
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size; int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size;
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
const index_t offset_k = block_index * params.stride_kv_block; const index_t offset_k = block_index * k_block_stride;
uint8_t* gK_base = k_ptr + offset_k + rel_idx_in_block*params.stride_kv_row; uint8_t* gK_base;
float scales[NUM_SCALES]; float scales[NUM_SCALES];
if constexpr (MODEL_TYPE == ModelType::V32) { if constexpr (MODEL_TYPE == ModelType::V32) {
gK_base = k_ptr + offset_k + rel_idx_in_block * k_row_stride;
float* scale_ptr = (float*)(gK_base + HEAD_DIM_NOPE); float* scale_ptr = (float*)(gK_base + HEAD_DIM_NOPE);
static_assert(NUM_SCALES == 4); static_assert(NUM_SCALES == 4);
static_assert(HEAD_DIM_NOPE == 512); static_assert(HEAD_DIM_NOPE == 512);
...@@ -198,8 +199,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -198,8 +199,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
} }
} }
} else { } else {
gK_base = k_ptr + offset_k + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*2);;
static_assert(NUM_SCALES == 8); static_assert(NUM_SCALES == 8);
uint8_t* scale_ptr = gK_base + HEAD_DIM_NOPE + HEAD_DIM_ROPE * 2; uint8_t* scale_ptr = k_ptr + offset_k + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*2) + rel_idx_in_block*NUM_SCALES;
if (token_index == -1) if (token_index == -1)
{ {
for (int i = 0; i < NUM_SCALES; i++) for (int i = 0; i < NUM_SCALES; i++)
...@@ -221,21 +223,27 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -221,21 +223,27 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
float as_value; float as_value;
}; };
Fp32 fp32; Fp32 fp32;
for (int i = 0; i < NUM_SCALES; i++) for (int i = 0; i < NUM_SCALES - 1; i++)
{ {
fp32.as_bits = (scale_e8m0.fp8_e8m0[i] << 23); fp32.as_bits = (scale_e8m0.fp8_e8m0[i] << 23);
scales[i] = fp32.as_value; scales[i] = fp32.as_value;
} }
} }
// if (block0() && threadIdx.x < 64)
// {
// printf("tidx = %d, %.3f %.2f %.2f \n",tidx, scales[0], scales[1], scales[2]);
// }
} }
// zhj debug // // zhj debug
// if (head_block_idx == 0 && threadIdx.x < 64) // if (head_block_idx == 0 && threadIdx.x < 64)
// { // {
// printf("tidx = %d, %.2f %.2f %.2f %.2f %d offset_k = %d token_indexrel_idx_in_block = %d params.stride_kv_row = %d %p params.kv = %p \n", tidx, float(scales[0]), float(scales[1]), float(scales[2]), float(scales[3]), // printf("tidx = %d, %.2f %.2f %.2f %.2f %d offset_k = %d rel_idx_in_block = %d params.stride_kv_row = %d %p params.kv = %p \n", tidx, float(scales[0]), float(scales[1]), float(scales[2]), float(scales[3]),
// token_index, // token_index,
// offset_k, // offset_k,
// token_indexrel_idx_in_block, // rel_idx_in_block,
// params.stride_kv_row, // params.stride_kv_row,
// gK_base, // gK_base,
// params.kv // params.kv
...@@ -323,7 +331,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -323,7 +331,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
for (int k_idx = 4; k_idx < 7; k_idx++) for (int k_idx = 4; k_idx < 7; k_idx++)
{ {
for (int j = 0; j < 16; j+=4) { for (int j = 0; j < 16; j+=4) {
auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx - 4], scales[k_idx / 2], j); auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx - 4], scales[k_idx], j);
tSrK(j, 0, k_idx) = rst0; tSrK(j, 0, k_idx) = rst0;
tSrK(j + 1, 0, k_idx) = rst1; tSrK(j + 1, 0, k_idx) = rst1;
...@@ -346,8 +354,51 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -346,8 +354,51 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
} }
// if (head_block_idx == 0)
// {
// printf("tidx = %d, %.2f %.2f %.2f %.2f \n", tidx, float(acc_s(0)), float(acc_s(1)), float(acc_s(2)), float(acc_s(3)));
// }
bf16_storage bf16_data0;
bf16_storage bf16_data1;
if (token_index == -1)
{
bf16_data0.data_128 = {0};
bf16_data1.data_128 = {0};
}
else
{
bf16_data0.data_128 = *((uint32x4_t*)(gK_base + col_idx * 16 * 2 + HEAD_DIM_NOPE));
bf16_data1.data_128 = *((uint32x4_t*)(gK_base + col_idx * 16 * 2 + 8 * 2 + HEAD_DIM_NOPE));
}
for (int j = 0; j < 8; j++) {
auto rst = cutlass::bfloat16_t::bitcast(bf16_data0.data_array[j]);
tSrK(j, 0, 7) = rst;
}
for (int j = 8; j < 16; j++) {
auto rst = cutlass::bfloat16_t::bitcast(bf16_data1.data_array[j - 8]);
tSrK(j, 0, 7) = rst;
}
constexpr static int k_idx = 7;
// if (block0() && threadIdx.x >= 192)
// {
// printf(" %.4f %.4f %.4f %.4f \n",
// float(tSrK(0, 0, 7)),
// float(tSrK(1, 0, 7)),
// float(tSrK(2, 0, 7)),
// float(tSrK(3, 0, 7))
// );
// }
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
#pragma unroll
for (int j = 0; j < 8; j++) {
// tOsV(j, 0, (k_idx - 4) * 2) = Element(j);
tOsV(j, 0, (k_idx - 4) * 2) = tSrK(j, 0, k_idx);
}
#pragma unroll
for (int j = 8; j < 16; j++) {
tOsV(j - 8, 0, (k_idx - 4) * 2 + 1) = tSrK(j, 0, k_idx);
}
} }
__syncthreads(); __syncthreads();
...@@ -360,7 +411,14 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -360,7 +411,14 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
flash::__ds_read_m32x16_row_col_rrow<1, 1, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<1, 1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 2, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<1, 2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 3, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<1, 3, 3>(tOsVt, tOrVt_copy_view);
__syncthreads(); __syncthreads();
// if (block0() && threadIdx.x >= 192)
// {
// printf(" %.4f %.4f %.4f %.4f %p %p\n",
// float(tOsVt(0, 3, 0)), float(tOsVt(1, 3, 0)), float( tSrK(8, 0, 7)), float( tSrK(9, 0, 7)),
// &(tOsVt(0, 1, 3)), &(tOsV(0, 0, 7)));
// }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
Fp8_storage data[4]; Fp8_storage data[4];
// __ds_read_m64x16_row_col_rrow<0, 0, 4>(tOsVt, tOrVt_copy_view); // __ds_read_m64x16_row_col_rrow<0, 0, 4>(tOsVt, tOrVt_copy_view);
...@@ -430,11 +488,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -430,11 +488,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor cS = make_identity_tensor(Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}); Tensor cS = make_identity_tensor(Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
Tensor tScS = thr_mma.partition_C(cS); Tensor tScS = thr_mma.partition_C(cS);
auto is_index_valid = [&](int index) -> bool {
if constexpr (MODEL_TYPE == ModelType::V32) {
return indices_base[int(get<1>(tScS(index)))] != -1;
} else {
return indices_base[int(get<1>(tScS(index)))] != -1 && (rel_block_idx*TOPK_BLOCK_SIZE + int(get<1>(tScS(index))) < topk_length);
}
};
for (int i = 0; i < size(acc_s); ++i) { for (int i = 0; i < size(acc_s); ++i) {
int idx = int(get<1>(tScS(i))) + block_idx * TOPK_BLOCK_SIZE; // int idx = indices_base[int(get<1>(tScS(i)))] ;
idx = gIndices[idx] ; if (not is_index_valid(i)) acc_s(i) = -INFINITY;
if (idx == -1) acc_s(i) = -INFINITY;
} }
block_idx == 0 block_idx == 0
? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2) ? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2)
...@@ -655,3 +718,6 @@ void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &param ...@@ -655,3 +718,6 @@ void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &param
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment