Commit 75890221 authored by zhanghj2's avatar zhanghj2
Browse files

支持movel1 decode

parent 620f8769
......@@ -176,10 +176,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
int token_index = indices_base[(lane_idx % 16) + warp_idx * 16];
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size;
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
const index_t offset_k = block_index * params.stride_kv_block;
uint8_t* gK_base = k_ptr + offset_k + rel_idx_in_block*params.stride_kv_row;
const index_t offset_k = block_index * k_block_stride;
uint8_t* gK_base;
float scales[NUM_SCALES];
if constexpr (MODEL_TYPE == ModelType::V32) {
gK_base = k_ptr + offset_k + rel_idx_in_block * k_row_stride;
float* scale_ptr = (float*)(gK_base + HEAD_DIM_NOPE);
static_assert(NUM_SCALES == 4);
static_assert(HEAD_DIM_NOPE == 512);
......@@ -198,8 +199,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
}
} else {
gK_base = k_ptr + offset_k + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*2);;
static_assert(NUM_SCALES == 8);
uint8_t* scale_ptr = gK_base + HEAD_DIM_NOPE + HEAD_DIM_ROPE * 2;
uint8_t* scale_ptr = k_ptr + offset_k + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*2) + rel_idx_in_block*NUM_SCALES;
if (token_index == -1)
{
for (int i = 0; i < NUM_SCALES; i++)
......@@ -221,21 +223,27 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
float as_value;
};
Fp32 fp32;
for (int i = 0; i < NUM_SCALES; i++)
for (int i = 0; i < NUM_SCALES - 1; i++)
{
fp32.as_bits = (scale_e8m0.fp8_e8m0[i] << 23);
scales[i] = fp32.as_value;
}
}
// if (block0() && threadIdx.x < 64)
// {
// printf("tidx = %d, %.3f %.2f %.2f \n",tidx, scales[0], scales[1], scales[2]);
// }
}
// zhj debug
// // zhj debug
// if (head_block_idx == 0 && threadIdx.x < 64)
// {
// printf("tidx = %d, %.2f %.2f %.2f %.2f %d offset_k = %d token_indexrel_idx_in_block = %d params.stride_kv_row = %d %p params.kv = %p \n", tidx, float(scales[0]), float(scales[1]), float(scales[2]), float(scales[3]),
// printf("tidx = %d, %.2f %.2f %.2f %.2f %d offset_k = %d rel_idx_in_block = %d params.stride_kv_row = %d %p params.kv = %p \n", tidx, float(scales[0]), float(scales[1]), float(scales[2]), float(scales[3]),
// token_index,
// offset_k,
// token_indexrel_idx_in_block,
// rel_idx_in_block,
// params.stride_kv_row,
// gK_base,
// params.kv
......@@ -323,7 +331,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
for (int k_idx = 4; k_idx < 7; k_idx++)
{
for (int j = 0; j < 16; j+=4) {
auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx - 4], scales[k_idx / 2], j);
auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx - 4], scales[k_idx], j);
tSrK(j, 0, k_idx) = rst0;
tSrK(j + 1, 0, k_idx) = rst1;
......@@ -346,8 +354,51 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
}
// if (head_block_idx == 0)
// {
// printf("tidx = %d, %.2f %.2f %.2f %.2f \n", tidx, float(acc_s(0)), float(acc_s(1)), float(acc_s(2)), float(acc_s(3)));
// }
bf16_storage bf16_data0;
bf16_storage bf16_data1;
if (token_index == -1)
{
bf16_data0.data_128 = {0};
bf16_data1.data_128 = {0};
}
else
{
bf16_data0.data_128 = *((uint32x4_t*)(gK_base + col_idx * 16 * 2 + HEAD_DIM_NOPE));
bf16_data1.data_128 = *((uint32x4_t*)(gK_base + col_idx * 16 * 2 + 8 * 2 + HEAD_DIM_NOPE));
}
for (int j = 0; j < 8; j++) {
auto rst = cutlass::bfloat16_t::bitcast(bf16_data0.data_array[j]);
tSrK(j, 0, 7) = rst;
}
for (int j = 8; j < 16; j++) {
auto rst = cutlass::bfloat16_t::bitcast(bf16_data1.data_array[j - 8]);
tSrK(j, 0, 7) = rst;
}
constexpr static int k_idx = 7;
// if (block0() && threadIdx.x >= 192)
// {
// printf(" %.4f %.4f %.4f %.4f \n",
// float(tSrK(0, 0, 7)),
// float(tSrK(1, 0, 7)),
// float(tSrK(2, 0, 7)),
// float(tSrK(3, 0, 7))
// );
// }
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
#pragma unroll
for (int j = 0; j < 8; j++) {
// tOsV(j, 0, (k_idx - 4) * 2) = Element(j);
tOsV(j, 0, (k_idx - 4) * 2) = tSrK(j, 0, k_idx);
}
#pragma unroll
for (int j = 8; j < 16; j++) {
tOsV(j - 8, 0, (k_idx - 4) * 2 + 1) = tSrK(j, 0, k_idx);
}
}
__syncthreads();
......@@ -360,7 +411,14 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
flash::__ds_read_m32x16_row_col_rrow<1, 1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 3, 3>(tOsVt, tOrVt_copy_view);
__syncthreads();
// if (block0() && threadIdx.x >= 192)
// {
// printf(" %.4f %.4f %.4f %.4f %p %p\n",
// float(tOsVt(0, 3, 0)), float(tOsVt(1, 3, 0)), float( tSrK(8, 0, 7)), float( tSrK(9, 0, 7)),
// &(tOsVt(0, 1, 3)), &(tOsV(0, 0, 7)));
// }
__builtin_amdgcn_sched_barrier(0);
Fp8_storage data[4];
// __ds_read_m64x16_row_col_rrow<0, 0, 4>(tOsVt, tOrVt_copy_view);
......@@ -430,11 +488,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor cS = make_identity_tensor(Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
Tensor tScS = thr_mma.partition_C(cS);
auto is_index_valid = [&](int index) -> bool {
if constexpr (MODEL_TYPE == ModelType::V32) {
return indices_base[int(get<1>(tScS(index)))] != -1;
} else {
return indices_base[int(get<1>(tScS(index)))] != -1 && (rel_block_idx*TOPK_BLOCK_SIZE + int(get<1>(tScS(index))) < topk_length);
}
};
for (int i = 0; i < size(acc_s); ++i) {
int idx = int(get<1>(tScS(i))) + block_idx * TOPK_BLOCK_SIZE;
idx = gIndices[idx] ;
if (idx == -1) acc_s(i) = -INFINITY;
// int idx = indices_base[int(get<1>(tScS(i)))] ;
if (not is_index_valid(i)) acc_s(i) = -INFINITY;
}
block_idx == 0
? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2)
......@@ -655,3 +718,6 @@ void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &param
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment