Commit 7efb944d authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz fp8 tp8

parent 72b2aea0
...@@ -905,15 +905,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -905,15 +905,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[i].data; *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[i].data;
kv_lds_write_ptr += 64*64; kv_lds_write_ptr += 64*64;
} }
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_qkvfp8<false, true, true, false>(gK, kv_data[8].data, 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
} }
// if (block0()) // if (block0())
// { // {
...@@ -930,25 +921,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -930,25 +921,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
__syncthreads(); __syncthreads();
for (int i = 0; i < 8; i++) {
for (int i = 0; i < 7; i++) {
cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i)); cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
} }
// cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
// cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
Fp8_storage v3_0, v3_1; Fp8_storage v3_0, v3_1;
__ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data); __ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data);
__ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data); __ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data);
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
{ {
intx4_t* d = reinterpret_cast<intx4_t*>(&tSrK(0, 0, 8)); intx4_t* d = reinterpret_cast<intx4_t*>(&tSrK(0, 0, 8));
*d = kv_data[8].data; *d = kv_data[8].data;
} }
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
for (int i = 0; i < 8; i++) {
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
}
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
// if (thread0()) { // if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3)); // printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment