Commit 620f8769 authored by zhanghj2
Browse files

修改支持modelv1,v32部分通过,model1未修改完

parent 6fb681fc
...@@ -122,6 +122,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -122,6 +122,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
typedef unsigned short int __hip_fp8x2_storage_t; typedef unsigned short int __hip_fp8x2_storage_t;
typedef unsigned char __hip_fp8_storage_t; typedef unsigned char __hip_fp8_storage_t;
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
typedef __fp16 __fp16x4_t __attribute__((ext_vector_type(4)));
union Fp8_storage{ union Fp8_storage{
__fp16x8_t data_128; __fp16x8_t data_128;
...@@ -137,6 +138,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -137,6 +138,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{}); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{});
clear(acc_o); clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax; flash::Softmax<size<1>(acc_o)> softmax;
MainloopArgs args = get_cur_req_info(batch_idx);
struct IsOrigBlock {}; struct IsOrigBlock {};
struct IsExtraBlock {}; struct IsExtraBlock {};
...@@ -145,28 +147,88 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -145,28 +147,88 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}); Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
clear(acc_s); clear(acc_s);
int col_idx = lane_idx / 16; int col_idx = lane_idx / 16;
int token_index = gIndices[block_idx * TOPK_BLOCK_SIZE + (lane_idx % 16) + warp_idx * 16];
int page_block_size = params.page_block_size; int* indices_base;
int page_block_size;
int64_t k_block_stride, k_row_stride;
uint8_t* k_ptr;
if constexpr (!IS_EXTRA_BLOCK) {
indices_base = gIndices + (block_idx)*TOPK_BLOCK_SIZE;
page_block_size = params.page_block_size;
k_block_stride = params.stride_kv_block;
k_row_stride = params.stride_kv_row;
k_ptr = (uint8_t*)params.kv;
} else {
indices_base = gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE;
page_block_size = params.extra_page_block_size;
k_block_stride = params.stride_extra_kv_block;
k_row_stride = params.stride_extra_kv_row;
k_ptr = (uint8_t*)params.extra_kv;
}
[[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length;
[[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx;
if constexpr (MODEL_TYPE == ModelType::MODEL1) {
if (rel_block_idx*TOPK_BLOCK_SIZE >= topk_length)
{
}
}
int token_index = indices_base[(lane_idx % 16) + warp_idx * 16];
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size; int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance const int token_indexrel_idx_in_block = (token_index + page_block_size) % page_block_size;
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
const index_t offset_k = block_index * params.stride_kv_block; const index_t offset_k = block_index * params.stride_kv_block;
uint8_t* gK_base = (uint8_t*)params.kv + offset_k + rel_idx_in_block*params.stride_kv_row; uint8_t* gK_base = k_ptr + offset_k + rel_idx_in_block*params.stride_kv_row;
float* scale_ptr = (float*)(gK_base + 512); float scales[NUM_SCALES];
float scales[4]; if constexpr (MODEL_TYPE == ModelType::V32) {
if (token_index == -1) float* scale_ptr = (float*)(gK_base + HEAD_DIM_NOPE);
{ static_assert(NUM_SCALES == 4);
scales[0] = 0.0f; static_assert(HEAD_DIM_NOPE == 512);
scales[1] = 0.0f; if (token_index == -1)
scales[2] = 0.0f; {
scales[3] = 0.0f; for (int i = 0; i < NUM_SCALES; i++)
} {
else scales[i] = 0.0f;
{ }
for (int i = 0; i < 4; i++) }
else
{ {
scales[i] = scale_ptr[i]; for (int i = 0; i < NUM_SCALES; i++)
{
scales[i] = scale_ptr[i];
}
}
} else {
static_assert(NUM_SCALES == 8);
uint8_t* scale_ptr = gK_base + HEAD_DIM_NOPE + HEAD_DIM_ROPE * 2;
if (token_index == -1)
{
for (int i = 0; i < NUM_SCALES; i++)
{
scales[i] = 0.0f;
}
}
else
{
// Overlay used to split one 8-byte vector load into individual scale bytes:
// `tmp` receives the four-__fp16 (8-byte) vector read from scale_ptr, and
// `fp8_e8m0` exposes the same storage as NUM_SCALES single-byte entries
// (NUM_SCALES is static_assert'ed to 8 in this branch).
// NOTE(review): the "e8m0" name suggests exponent-only scale bytes - confirm against the KV-cache layout.
union Scale_e8m0
{
__fp16x4_t tmp;
__hip_fp8_storage_t fp8_e8m0[NUM_SCALES];
};
Scale_e8m0 scale_e8m0;
scale_e8m0.tmp = *(__fp16x4_t*)(scale_ptr);
// Bit-level view of a 32-bit float: the loop below writes `as_bits` with a scale byte
// shifted left by 23 (the exponent field position of an IEEE-754 binary32) and then
// reads the result back through `as_value`.
union Fp32{
uint32_t as_bits;
float as_value;
};
Fp32 fp32;
for (int i = 0; i < NUM_SCALES; i++)
{
fp32.as_bits = (scale_e8m0.fp8_e8m0[i] << 23);
scales[i] = fp32.as_value;
}
} }
} }
// zhj debug // zhj debug
// if (head_block_idx == 0 && threadIdx.x < 64) // if (head_block_idx == 0 && threadIdx.x < 64)
// { // {
...@@ -179,73 +241,115 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -179,73 +241,115 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// params.kv // params.kv
// ); // );
// } // }
// dequant_to_bf16: decode four packed fp8 values, scale them, and convert them to `Element`.
//   data0    - storage union holding the raw fp8 bytes; the 4 bytes starting at
//              data0.fp8_array[idx / 4] are the ones decoded.
//   kv_scale - factor multiplied into each decoded float before the final conversion.
//   idx      - element index; only idx / 4 is used, to pick the fp8_array entry.
// Returns a tuple of the four converted values (first 16-bit half low byte, first half high
// byte, second half low byte, second half high byte in the generic path).
// NOTE(review): the name says bf16, but the output type is whatever `Element` is - confirm.
Fp8_storage data[4]; auto dequant_to_bf16 = [&](const Fp8_storage& data0, const float& kv_scale, int idx) -> std::tuple<Element, Element, Element, Element> {
for (int k_idx = 4; k_idx < 8; k_idx++) #if defined(__gfx938__)
// gfx938 path: two calls to the packed fp8->f32 conversion builtin on the same source word,
// with the second argument false then true (presumably low/high 16-bit word select -
// confirm against the builtin's documentation). Each call yields two floats.
auto res1 = __builtin_amdgcn_cvt_pk_f32_fp8(data0.fp8_array[idx/4], false);
auto res2 = __builtin_amdgcn_cvt_pk_f32_fp8(data0.fp8_array[idx/4], true);
auto f1 = res1[0];
auto f2 = res1[1];
auto f3 = res2[0];
auto f4 = res2[1];
#else
// Generic path: view the selected entry as two consecutive 16-bit values, then decode each
// byte separately with flash::fp8e4m3_to_fp32. The casts to __hip_fp8_storage_t keep only
// the low 8 bits of each shifted value.
const auto fp8x2_low = *reinterpret_cast<const __hip_fp8x2_storage_t*>(&data0.fp8_array[idx / 4]);
const auto fp8x2_high = *(reinterpret_cast<const __hip_fp8x2_storage_t*>(&(data0.fp8_array[idx / 4])) + 1);
auto f1 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
#endif
// Apply the dequantization scale in float before narrowing.
f1 *= kv_scale;
f2 *= kv_scale;
f3 *= kv_scale;
f4 *= kv_scale;
// Narrow float -> Element using round-toward-zero.
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
return {rst0, rst1, rst2, rst3};
};
if constexpr (MODEL_TYPE == ModelType::V32)
{ {
if (token_index == -1) { Fp8_storage data[4];
data[k_idx - 4].data_128 = {0}; for (int k_idx = 4; k_idx < 8; k_idx++)
} else { {
data[k_idx - 4].data_128 = *((__fp16x8_t*)(gK_base + col_idx * 16 + k_idx * 64)); if (token_index == -1) {
data[k_idx - 4].data_128 = {0};
} else {
data[k_idx - 4].data_128 = *((__fp16x8_t*)(gK_base + col_idx * 16 + k_idx * 64));
}
} }
} for (int k_idx = 4; k_idx < 8; k_idx++)
{
for (int j = 0; j < 16; j+=4) {
auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx - 4], scales[k_idx / 2], j);
tSrK(j, 0, k_idx) = rst0;
tSrK(j + 1, 0, k_idx) = rst1;
tSrK(j + 2, 0, k_idx) = rst2;
tSrK(j + 3, 0, k_idx) = rst3;
for (int k_idx = 4; k_idx < 8; k_idx++) }
{ // cute::copy(smem_tiled_copy_K, tSrK(_, _, k_idx), tOsV(_, _, k_idx % 4));
for (int j = 0; j < 16; j+=4) { // __builtin_amdgcn_sched_barrier(0);
#if defined(__gfx938__)
auto res1 = __builtin_amdgcn_cvt_pk_f32_fp8(data[k_idx - 4].fp8_array[j/4], false);
auto res2 = __builtin_amdgcn_cvt_pk_f32_fp8(data[k_idx - 4].fp8_array[j/4], true);
auto f1 = res1[0];
auto f2 = res1[1];
auto f3 = res2[0];
auto f4 = res2[1];
#else
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx - 4].fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx - 4].fp8_array[j / 4])) + 1);
auto f1 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
#endif
f1 *= scales[k_idx / 2];
f2 *= scales[k_idx / 2];
f3 *= scales[k_idx / 2];
f4 *= scales[k_idx / 2];
// if (block0)
// {
// printf(" tidx = %d %.4f %.4f %.4f %.4f \n", threadIdx.x, f1, f2, f3, f4);
// }
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_; #pragma unroll
auto rst0 = convert_(f1); for (int j = 0; j < 8; j++) {
auto rst1 = convert_(f2); tOsV(j, 0, (k_idx - 4) * 2) = tSrK(j, 0, k_idx);
auto rst2 = convert_(f3); }
auto rst3 = convert_(f4); #pragma unroll
tSrK(j, 0, k_idx) = rst0; for (int j = 8; j < 16; j++) {
tSrK(j + 1, 0, k_idx) = rst1; tOsV(j - 8, 0, (k_idx - 4) * 2 + 1) = tSrK(j, 0, k_idx);
tSrK(j + 2, 0, k_idx) = rst2; }
tSrK(j + 3, 0, k_idx) = rst3; // __builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
} }
// cute::copy(smem_tiled_copy_K, tSrK(_, _, k_idx), tOsV(_, _, k_idx % 4)); }
// __builtin_amdgcn_sched_barrier(0); else
{
Fp8_storage data[3];
for (int k_idx = 4; k_idx < 7; k_idx++)
{
if (token_index == -1) {
data[k_idx - 4].data_128 = {0};
} else {
data[k_idx - 4].data_128 = *((__fp16x8_t*)(gK_base + col_idx * 16 + k_idx * 64));
}
}
for (int k_idx = 4; k_idx < 7; k_idx++)
{
for (int j = 0; j < 16; j+=4) {
auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx - 4], scales[k_idx / 2], j);
tSrK(j, 0, k_idx) = rst0;
tSrK(j + 1, 0, k_idx) = rst1;
tSrK(j + 2, 0, k_idx) = rst2;
tSrK(j + 3, 0, k_idx) = rst3;
#pragma unroll }
for (int j = 0; j < 8; j++) { // cute::copy(smem_tiled_copy_K, tSrK(_, _, k_idx), tOsV(_, _, k_idx % 4));
tOsV(j, 0, (k_idx - 4) * 2) = tSrK(j, 0, k_idx); // __builtin_amdgcn_sched_barrier(0);
#pragma unroll
for (int j = 0; j < 8; j++) {
tOsV(j, 0, (k_idx - 4) * 2) = tSrK(j, 0, k_idx);
}
#pragma unroll
for (int j = 8; j < 16; j++) {
tOsV(j - 8, 0, (k_idx - 4) * 2 + 1) = tSrK(j, 0, k_idx);
}
// __builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
} }
#pragma unroll
for (int j = 8; j < 16; j++) {
tOsV(j - 8, 0, (k_idx - 4) * 2 + 1) = tSrK(j, 0, k_idx);
}
// __builtin_amdgcn_sched_barrier(0);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
} }
__syncthreads(); __syncthreads();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
flash::__ds_read_m32x16_row_col_rrow<0, 0, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<0, 0, 2>(tOsVt, tOrVt_copy_view);
...@@ -258,7 +362,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -258,7 +362,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
flash::__ds_read_m32x16_row_col_rrow<1, 3, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<1, 3, 3>(tOsVt, tOrVt_copy_view);
__syncthreads(); __syncthreads();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
Fp8_storage data[4];
// __ds_read_m64x16_row_col_rrow<0, 0, 4>(tOsVt, tOrVt_copy_view); // __ds_read_m64x16_row_col_rrow<0, 0, 4>(tOsVt, tOrVt_copy_view);
for (int k_idx = 0; k_idx < 4; k_idx++) for (int k_idx = 0; k_idx < 4; k_idx++)
{ {
...@@ -271,40 +375,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -271,40 +375,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
for (int k_idx = 0; k_idx < 4; k_idx++) for (int k_idx = 0; k_idx < 4; k_idx++)
{ {
for (int j = 0; j < 16; j+=4) { for (int j = 0; j < 16; j+=4) {
auto [rst0, rst1, rst2, rst3] = dequant_to_bf16(data[k_idx], scales[MODEL_TYPE == ModelType::V32 ? k_idx / 2 : k_idx], j);
#if defined(__gfx938__)
auto res1 = __builtin_amdgcn_cvt_pk_f32_fp8(data[k_idx].fp8_array[j/4], false);
auto res2 = __builtin_amdgcn_cvt_pk_f32_fp8(data[k_idx].fp8_array[j/4], true);
auto f1 = res1[0];
auto f2 = res1[1];
auto f3 = res2[0];
auto f4 = res2[1];
#else
auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&data[k_idx].fp8_array[j / 4]);
auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&(data[k_idx].fp8_array[j / 4])) + 1);
auto f1 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8));
auto f2 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_low >> 8)));
auto f3 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8));
auto f4 = flash::fp8e4m3_to_fp32(static_cast<__hip_fp8_storage_t>((fp8x2_high) >> 8));
#endif
f1 *= scales[k_idx / 2];
f2 *= scales[k_idx / 2];
f3 *= scales[k_idx / 2];
f4 *= scales[k_idx / 2];
// if (block0)
// {
// printf(" tidx = %d %.4f %.4f %.4f %.4f \n", threadIdx.x, f1, f2, f3, f4);
// }
cutlass::NumericConverter<Element, float, cutlass::FloatRoundStyle::round_toward_zero> convert_;
auto rst0 = convert_(f1);
auto rst1 = convert_(f2);
auto rst2 = convert_(f3);
auto rst3 = convert_(f4);
tSrK(j, 0, k_idx) = rst0; tSrK(j, 0, k_idx) = rst0;
tSrK(j + 1, 0, k_idx) = rst1; tSrK(j + 1, 0, k_idx) = rst1;
tSrK(j + 2, 0, k_idx) = rst2; tSrK(j + 2, 0, k_idx) = rst2;
...@@ -333,6 +404,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -333,6 +404,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
flash::__ds_read_m32x16_row_col_rrow<0, 2, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow<0, 3, 0>(tOsVt, tOrVt_copy_view);
if constexpr (MODEL_TYPE == ModelType::V32)
{ {
bf16_storage bf16_data0; bf16_storage bf16_data0;
bf16_storage bf16_data1; bf16_storage bf16_data1;
...@@ -360,12 +432,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -360,12 +432,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor tScS = thr_mma.partition_C(cS); Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) { for (int i = 0; i < size(acc_s); ++i) {
{ int idx = int(get<1>(tScS(i))) + block_idx * TOPK_BLOCK_SIZE;
int idx = int(get<1>(tScS(i))) + block_idx * TOPK_BLOCK_SIZE; idx = gIndices[idx] ;
idx = gIndices[idx] ; if (idx == -1) acc_s(i) = -INFINITY;
if (idx == -1) acc_s(i) = -INFINITY;
}
} }
block_idx == 0 block_idx == 0
? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2) ? softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2)
...@@ -399,9 +468,17 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -399,9 +468,17 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
} }
}; };
MainloopArgs args = get_cur_req_info(batch_idx); if constexpr (MODEL_TYPE == ModelType::V32) {
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) { for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {
process_one_block(block_idx, IsOrigBlock{}); process_one_block(block_idx, IsOrigBlock{});
}
} else {
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{});
}
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{});
}
} }
if (args.is_no_split) { if (args.is_no_split) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment