Commit 4d897ed1 authored by zhanghj2
Browse files

优化nmz tp1性能

parent 3722ec71
......@@ -683,9 +683,9 @@ mha_fwd_kvcache_mla_fp8(
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'");
// static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'");
setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1);
// setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1);
// std::cout << FLASH_MLA_ROOT_DIR << "\n";
// exit(-1);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
......
......@@ -1299,6 +1299,41 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
__syncthreads();
auto k_lds_read_ptr = sK.data().get() + (warp_id / 4) * 16 * 64;
constexpr static int k_read_lds_offset = 32 * 64;
// Fp8_storage data[9];
#if 0
Fp8_storage k_data[9];
__builtin_amdgcn_sched_barrier(0);
k_data[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 0 * 4096, 3, 1, 0);
k_data[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 1 * 4096, 3, 1, 0);
k_data[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 2 * 4096, 3, 1, 0);
k_data[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 3 * 4096, 3, 1, 0);
k_data[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 4 * 4096, 3, 1, 0);
k_data[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 5 * 4096, 3, 1, 0);
k_data[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 6 * 4096, 3, 1, 0);
k_data[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 7 * 4096, 3, 1, 0);
k_data[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 8 * 4096, 3, 1, 0);
#pragma unroll
for (int i = 0; i < 9; i++) {
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[0], k_data[i].p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[1], k_data[i].p[1], accs_f32[0], true, false);
}
k_data[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 0 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 1 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 2 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 3 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 4 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 5 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 6 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 7 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 8 * 4096 + k_read_lds_offset, 3, 1, 0);
#pragma unroll
for (int i = 0; i < 9; i++) {
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[0], k_data[i].p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[1], k_data[i].p[1], accs_f32[1], true, false);
}
__builtin_amdgcn_sched_barrier(0);
#else
{
constexpr static int k_idx = 0;
// k_lds_read_ptr += k_idx * 64 * 64;
......@@ -1383,7 +1418,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 6;
// k_lds_read_ptr += 64 * 64;
......@@ -1425,6 +1459,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
#endif
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w;
acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w;
......@@ -1534,299 +1569,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// asm volatile("s_barrier \n\t");
};
#if 0
// process_one_block: per-block step of the split-KV loop (compiled out by the enclosing `#if 0`).
// For the block at `block_idx` it:
//   1. looks up the block-table entry and issues the K loads into sK,
//   2. accumulates nine k_idx slices of mmac results into accs_f32[0] and accs_f32[1],
//   3. copies them into acc_s and writes -INFINITY into out-of-range columns,
//   4. calls softmax_rescale_o_fp8_tp1 on acc_s,
//   5. packs acc_s to fp8, round-trips it through sP and accumulates against sV into acco_f32.
// `is_mask_block_t` is a tag value: only its type is inspected, to select the masking variant.
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
// NOTE(review): IS_MASK_BLOCK is compared against IsNoMaskBlock (the same tag as IS_NO_MASK_BLOCK)
// and is never referenced in this lambda -- looks like a copy-paste slip; confirm the intended tag.
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
// Two 4-float accumulators, zeroed before the mmac chain below.
v4f accs_f32[2];
for (int i = 0; i < 2; i++)
{
accs_f32[i].x = 0.0f;
accs_f32[i].y = 0.0f;
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// asm volatile("s_barrier \n\t");
// Fetch block_table[block_idx] with a scalar load and wait for it (lgkmcnt(0)) before use.
int cur_block_table;
const int *cur_block_table_ptr = block_table + block_idx;
// cur_block_table = block_table[block_idx - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
// gK.data() = gK.data() + (offset_k);
// gK is advanced by offset_k here and moved back by the same amount after the loads are issued.
gK.data() = gK.data() + (offset_k);
// Per-warp start offset: warp_id / 4 contributes a 64-element step, warp_id % 4 a 16-row step.
auto gK_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
// Number of rows by which this warp's 16-row slice ends past seqlen_k, clamped below at 0.
// How make_rscr uses this value is not visible here -- TODO confirm.
const int k_zero_pad = std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0);
uint32x4_t gK_rscr = make_rscr((unsigned char*)(gK.data().get() + gK_offset), params.k_row_stride, k_zero_pad);
auto k_lds_addr = reinterpret_cast<size_t>(sK.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64);
// Issue the matrix loads when this warp's first row lies inside seqlen_k (always for a no-mask block);
// otherwise call lds_direct_copy_qkvfp8_zero_lds for all five slots (presumably zero-filling sK -- verify).
if (block_idx * kBlockN + ((warp_id) % 4) * 16 < seqlen_k || IS_NO_MASK_BLOCK)
{
// NOTE(review): the meaning of the top address bit set here is not shown in this file section -- confirm.
k_lds_addr |= 0x80000000;
// Four loads at source offsets 0/128/256/384; the LDS destination advances by 64 * 128 after each.
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 0, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256+128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
// Fifth slot (source offset 512): only warps 0-3 load it; the remaining warps take the zero_lds path.
if (warp_id < 4)
{
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 512, 1, 1, 0, 0);
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 0);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 1);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 2);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 3);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
// Undo the offset_k advance applied to gK above.
gK.data() = gK.data() + ( - offset_k);
auto k_lds_read_ptr = sK.data().get() + (warp_id / 4) * 16 * 64;
// The vmcnt thresholds below step down 4 -> 3 -> 2 -> 1 -> 0, each followed by s_barrier, so the
// k_idx stanzas are interleaved with progressively stricter waits on the outstanding loads.
// Presumably each threshold corresponds to one of the five loads issued above -- verify.
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
constexpr static int k_read_lds_offset = 32 * 64;
// Every k_idx stanza (0..8) follows the same pattern: read a tile from sK at k_idx * 64 * 64 and feed
// its two halves (p[0], p[1]) with q_r[k_idx] into accs_f32[0]; then read the tile at
// + k_read_lds_offset and accumulate into accs_f32[1] the same way.
{
constexpr static int k_idx = 0;
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 1;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
// if (block0())
// {
// printf(" %x %x %x %x %x %x %x %x \n", q_r[k_idx].fp8_array[0], q_r[k_idx].fp8_array[1], q_r[k_idx].fp8_array[2], q_r[k_idx].fp8_array[3], k_data.fp8_array[0], k_data.fp8_array[1], k_data.fp8_array[2], k_data.fp8_array[3]);
// }
}
#if 1
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 2;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 3;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 4;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 5;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 6;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 7;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 8;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
#endif
// Move the two raw accumulators into the cute fragment acc_s: accs_f32[0] -> (.., 0, 0), accs_f32[1] -> (.., 0, 1).
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w;
acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
// #endif
// Masking is compiled in for every tag type except IsNoMaskBlock.
if constexpr (!IS_NO_MASK_BLOCK) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) {
// Non-causal: columns at or beyond the remaining sequence length in this block are set to -INFINITY.
if (int(get<1>(tScS(i))) >= int(seqlen_k - block_idx * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
}
// asm volatile("s_barrier \n\t");
softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
// #if 1
// Pack acc_s to fp8, exchange it through sP, then accumulate against sV into acco_f32.
Fp8_storage p_fp8;
{
__builtin_amdgcn_sched_barrier(0);
// These locals are derived from threadIdx.x; `warp_id` shadows the one used earlier in this lambda.
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
// NOTE(review): `result` and `result1` are passed uninitialized as the third argument of their first
// cvt call. Presumably the false/true last argument selects the low/high half so that the two calls
// together overwrite every bit -- confirm, or initialize to 0.
int32_t result;
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), acc_s(1, 0, 0), result, false);
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true);
// int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64]));
// *lds_ptr = result;
int32_t result1;
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false);
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true);
// lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8]));
// *lds_ptr = result1;
// Each thread stores its two packed words into sP; the second store sits 8 elements after the first.
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64]));
*lds_ptr = result;
int32_t* lds_ptr1 = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64 + 8]));
*lds_ptr1 = result1;
// Block-wide barrier between the sP stores above and the sP read below.
__syncthreads();
p_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16 + (warp_id % 4) * 16 * 64]));
__builtin_amdgcn_sched_barrier(0);
// Per-lane sV offset: 4 lanes per row, with the column index rotated by (row / 2) % 4.
int lane_id = tidx % 64;
int row = lane_id / 4;
int col = lane_id % 4;
col = (col + (row / 2) % 4) % 4;
auto lds_offset = row * 64 + col * 16 + (warp_id / 4) * 64 * 64;
// Fp8_storage v0_0, v0_1;
// v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(A_smem + lds_offset));
// v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64));
// acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
// For each of 4 sV slabs (stride 64 * 128) update acco_f32[n * 4 + j], j = 0..3, twice:
// first with p_fp8.p[0] on the tiles at the base offset, then with p_fp8.p[1] on the tiles at + 32 * 64.
for (int n = 0; n < 4; n++)
{
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128));
for (int j = 0; j < 4; j++)
{
// Pair the j-th word of each tile into the 2-word operand passed to the mmac.
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[0], v, acco_f32[n * 4 + j], true, false);
}
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128 + 32 * 64));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128 + 32 * 64));
for (int j = 0; j < 4; j++)
{
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[1], v, acco_f32[n * 4 + j], true, false);
}
}
}
// Final barrier at the end of the block step.
asm volatile("s_barrier \n\t");
};
#endif
#if 1
if constexpr (n_masking_steps == 1) {
......@@ -2824,9 +2566,9 @@ template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla_fp8_tp1(Flash_fwd_mla_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
const static bool disable_asm = get_env_("FLASH_MLA_DISABLE_ASM");
// const static bool enable_asm = get_env_("FLASH_MLA_ENABLE_ASM");
if (disable_asm) {
if (1) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
auto kernel = &flash::flash_fwd_splitkv_mla_kernel_fp8_tp1<Kernel_traits, Is_causal, SharedStorage>;
constexpr size_t smem_size = 65536;
......@@ -2834,54 +2576,54 @@ void run_flash_splitkv_fwd_mla_fp8_tp1(Flash_fwd_mla_params &params, cudaStream_
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
}
else {
static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR");
assert(FLASH_MLA_ASM_DIR != nullptr && "FLASH_MLA_ASM_DIR nullptr \n");
constexpr size_t smem_size = 65536;
std::string co_file = std::string(FLASH_MLA_ASM_DIR) +
"flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co";
hipError_t status = hipSuccess;
static hipModule_t fwd_module_sample;
static bool IS_FWD_MODULE_LOADED = false;
// else {
// static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR");
// assert(FLASH_MLA_ASM_DIR != nullptr && "FLASH_MLA_ASM_DIR nullptr \n");
// constexpr size_t smem_size = 65536;
// std::string co_file = std::string(FLASH_MLA_ASM_DIR) +
// "flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co";
// hipError_t status = hipSuccess;
// static hipModule_t fwd_module_sample;
// static bool IS_FWD_MODULE_LOADED = false;
if (IS_FWD_MODULE_LOADED == false)
{
status = hipModuleLoad(&fwd_module_sample, co_file.c_str());
if (status not_eq hipSuccess) {
printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str());
return;
}
IS_FWD_MODULE_LOADED = true;
}
size_t params_size = sizeof(params);
void* config[] = {
HIP_LAUNCH_PARAM_BUFFER_POINTER,
&params,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&params_size,
HIP_LAUNCH_PARAM_END
};
dim3 grid(num_m_block, params.h, params.num_sm_parts);
std::string kernel_name = params.is_causal ?
"_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params":
"_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params";
hipFunction_t flash_mla_func;
status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str());
status = hipModuleLaunchKernel(
flash_mla_func,
grid.x, grid.y, grid.z,
Kernel_traits::kNThreads, 1, 1,
smem_size, // shared memory
stream, // stream
NULL,
(void**)&config
);
if (status not_eq hipSuccess) {
printf("[flashmla] EXIT: failed to launch kernel!\n");
return;
}
}
// if (IS_FWD_MODULE_LOADED == false)
// {
// status = hipModuleLoad(&fwd_module_sample, co_file.c_str());
// if (status not_eq hipSuccess) {
// printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str());
// return;
// }
// IS_FWD_MODULE_LOADED = true;
// }
// size_t params_size = sizeof(params);
// void* config[] = {
// HIP_LAUNCH_PARAM_BUFFER_POINTER,
// &params,
// HIP_LAUNCH_PARAM_BUFFER_SIZE,
// &params_size,
// HIP_LAUNCH_PARAM_END
// };
// dim3 grid(num_m_block, params.h, params.num_sm_parts);
// std::string kernel_name = params.is_causal ?
// "_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params":
// "_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params";
// hipFunction_t flash_mla_func;
// status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str());
// status = hipModuleLaunchKernel(
// flash_mla_func,
// grid.x, grid.y, grid.z,
// Kernel_traits::kNThreads, 1, 1,
// smem_size, // shared memory
// stream, // stream
// NULL,
// (void**)&config
// );
// if (status not_eq hipSuccess) {
// printf("[flashmla] EXIT: failed to launch kernel!\n");
// return;
// }
// }
CHECK_CUDA_KERNEL_LAUNCH();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment