Commit 29b0d13b authored by zhanghj2's avatar zhanghj2
Browse files

优化sparse decode fp8

parent 4648ec2f
......@@ -627,14 +627,14 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
__syncthreads();
__builtin_amdgcn_sched_barrier(0);
flash::__ds_read_m32x16_row_col_rrow<0, 0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 0, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<1, 3, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<1, 0, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<1, 1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<1, 2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<1, 3, 3>(tOsVt, tOrVt_copy_view);
__syncthreads();
// if (block0() && threadIdx.x >= 192)
......@@ -681,10 +681,10 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
__syncthreads();
flash::__ds_read_m32x16_row_col_rrow<0, 0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 0>(tOsVt, tOrVt_copy_view);
if constexpr (MODEL_TYPE == ModelType::V32)
{
......@@ -735,21 +735,21 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
{
// __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_alt<1, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_alt<1, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o);
cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o);
// __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_alt<1, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_alt<1, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
......@@ -774,6 +774,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// {
// printf(" %.4f %.4f \n", acc_o(0), acc_o(1));
// }
auto float2bf16 = [] (float s) -> uint16_t {
uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32 += 0x8000u;
#endif
return uint16_t(x32 >> 16);
};
if (args.is_no_split) {
int start_head_idx = head_block_idx*BLOCK_M;
Tensor lse = softmax.template normalize_softmax_lse<false>(acc_o, sRow_sum_reduce_buffer, params.sm_scale);
......@@ -805,10 +815,24 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
}
// if (block0() && tidx % 16 == 0)
// {
// printf(" tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %3f \n ",
// tidx,
// float(acc_o(0)),
// float(acc_o(1)),
// float(acc_o(2)),
// float(acc_o(3)),
// float(acc_o(4)),
// float(acc_o(5)),
// float(acc_o(6)),
// float(acc_o(7))
// );
// }
float* gSoftmaxLse = (float*)params.lse + batch_idx * params.stride_lse_b + start_head_idx + s_q_idx * params.stride_lse_s_q; // (BLOCK_M) : (1)
{
auto rO = flash::convert_type<Element>(acc_o);
// auto rO = flash::convert_type<Element>(acc_o);
using result_type = cutlass::Array<Element, 2>;
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
......@@ -829,13 +853,29 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
|
|
*/
col = (laneId / 16) + ni * 128 + (warpId % 2) * 8 + (warpId / 2) * 64;
col = (laneId / 16) * 2 + ni * 128 + (warpId % 2) * 8 + (warpId / 2) * 64;
for (int i = 0; i < 4; i ++) {
for (int j = 0; j < 2; j++) {
gO(row, col) = rO(i * 2 + j, mi, ni);
col += 4;
}
col += 8;
#if defined(__gfx938__)
auto d = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(i, mi, ni), 0, acc_o(i + 4, mi, ni), 0);
auto res = reinterpret_cast<result_type const &>(d);
#else
result_type res;
Element e0, e1;
e0.storage = float2bf16(acc_o(i, mi, ni));
e1.storage = float2bf16(acc_o(i + 4, mi, ni));
res[0] = e0;
res[1] = e1;
#endif
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*(result_type*)(&gO(row, col)) = res;
col += 16;
// for (int j = 0; j < 2; j++) {
// gO(row, col) = rO(i * 2 + j, mi, ni);
// col += 4;
// }
// col += 8;
}
// for (int ei = 0; ei < size<0>(acc_o); ++ei) {
// gO(row, col) = rO(ei, mi, ni);
......@@ -883,13 +923,15 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// gOaccum(row, col) = acc_o(ei, mi, ni);
// col += 4;
// }
col = (laneId / 16) + ni * 128 + (warpId % 2) * 8 + (warpId / 2) * 64;
col = (laneId / 16) * 2 + ni * 128 + (warpId % 2) * 8 + (warpId / 2) * 64;
for (int i = 0; i < 4; i ++) {
for (int j = 0; j < 2; j++) {
gOaccum(row, col) = acc_o(i * 2 + j, mi, ni);
col += 4;
}
col += 8;
gOaccum(row, col) = acc_o(i, mi, ni);
gOaccum(row, col + 1) = acc_o(i + 4, mi, ni);
// for (int j = 0; j < 2; j++) {
// gOaccum(row, col) = acc_o(i * 2 + j, mi, ni);
// col += 4;
// }
col += 16;
}
}
......
......@@ -256,6 +256,28 @@ __forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1&
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_alt(Tensor0& src, Tensor1& dst)
{
auto lds = reinterpret_cast<__fp16 *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 2;
auto d = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(lds), offset);
uint16_t * d_ptr = reinterpret_cast<uint16_t*>(&d);
uint16_t * dst_ptr = reinterpret_cast<uint16_t*>(&(dst(0, row, col)));
dst_ptr[0] = d_ptr[0];
dst_ptr[1] = d_ptr[1];
dst_ptr[2] = d_ptr[2];
dst_ptr[3] = d_ptr[3];
dst_ptr[4] = d_ptr[4];
dst_ptr[5] = d_ptr[5];
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
inline __device__ float fp8e4m3_to_fp32(const fp8& input) {
const uint32_t w = (uint32_t)input << 24;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment