Commit 7efb944d authored by zhanghj2
Browse files

优化nmz fp8 tp8

parent 72b2aea0
......@@ -905,15 +905,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[i].data;
kv_lds_write_ptr += 64*64;
}
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_qkvfp8<false, true, true, false>(gK, kv_data[8].data, 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
}
// if (block0())
// {
......@@ -930,25 +921,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
// asm volatile("s_barrier\n\t");
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
__syncthreads();
for (int i = 0; i < 7; i++) {
for (int i = 0; i < 8; i++) {
cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
}
// cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
// cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
}
Fp8_storage v3_0, v3_1;
__ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data);
__ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data);
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
{
intx4_t* d = reinterpret_cast<intx4_t*>(&tSrK(0, 0, 8));
*d = kv_data[8].data;
}
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
for (int i = 0; i < 8; i++) {
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
}
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
// if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment