Commit 98b7c697 authored by zhanghj2
Browse files

fp8 tp1性能提升 (fp8 TP1 performance improvement)

parent 24c52aee
......@@ -1060,7 +1060,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
const int warp_id = tidx / 64;
const int warp_id = __builtin_amdgcn_readfirstlane(tidx / 64);
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
......@@ -1100,131 +1100,63 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
intx2_t p[2];
int32_t fp8_array[4];
};
#if 0
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
__syncthreads();
#else
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, true>(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp1<false, false>(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
uint8_t* q_lds_read_ptr = reinterpret_cast<uint8_t*>(sQ.data().get()) + (tidx % 64) * 16 + (warp_id % 4) * (16 * 64);
{
int k = 0;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
Fp8_storage q_r[9];
#if 1
auto gQ_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.q_row_stride;
const int q_zero_pad = std::min(std::max(m_block * kBlockM + ((warp_id) % 4 + 1) * 16 - params.seqlen_q, 0), 16);
uint32x4_t gQ_rscr = make_rscr((unsigned char*)(gQ.data().get() + gQ_offset), params.q_row_stride, q_zero_pad);
auto q_lds_addr = reinterpret_cast<size_t>(sQ.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64) | 0x80000000;
if (m_block * kBlockM + ((warp_id) % 4) * 16 < params.seqlen_q)
{
__builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 0, 1, 1, 0, 0);
q_lds_addr += 64*128;
__builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 128, 1, 1, 0, 0);
q_lds_addr += 64*128;
__builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 256, 1, 1, 0, 0);
q_lds_addr += 64*128;
__builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 256+128, 1, 1, 0, 0);
q_lds_addr += 64*128;
if (warp_id < 4)
{
__builtin_hcu_matrix_load_64x16_b8(gQ_rscr, (__attribute__((address_space(3))) char*)(q_lds_addr), 512, 1, 1, 0, 0);
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
else
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 4);
}
// int k = 0;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 2;
for (int i = 0; i < 16; i++)
else
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 2;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 0);
lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 1);
lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 2);
lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 3);
lds_direct_copy_qkvfp8_zero_lds(gQ, sQ, 4);
}
auto q_lds_read_ptr = sQ.data().get() + (warp_id % 4) * 16 * 64;
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
q_r[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 0, 3, 1, 0);
// q_lds_read_ptr += 64 * 64;
q_r[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 64*64, 3, 1, 0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
// q_lds_read_ptr += 64 * 64;
q_r[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 2*64*64, 3, 1, 0);
// q_lds_read_ptr += 64 * 64;
q_r[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 3*64*64, 3, 1, 0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 4;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 4;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
// q_lds_read_ptr += 64 * 64;
q_r[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 4*64*64, 3, 1, 0);
// q_lds_read_ptr += 64 * 64;
q_r[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 5*64*64, 3, 1, 0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 6;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
q_lds_read_ptr += 64*64;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k+1).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 6;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
// q_lds_read_ptr += 64 * 64;
q_r[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 6*64*64, 3, 1, 0);
// q_lds_read_ptr += 64 * 64;
q_r[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 7*64*64, 3, 1, 0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
{
q_lds_read_ptr += 64*64;
int k = 8;
for (int i = 0; i < 16; i++)
{
tSrQ(i, 0, k).storage = q_lds_read_ptr[i];
}
// q_lds_read_ptr += 64*64;
// int k = 8;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
// q_lds_read_ptr += 64 * 64;
q_r[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 8*64*64, 3, 1, 0);
__syncthreads();
#endif
......@@ -1263,10 +1195,21 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acco_f32[i].w = 0.0f;
}
constexpr static int STAGE = 8;
#if 1
for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
clear(acc_s);
v4f accs_f32[2];
for (int i = 0; i < 2; i++)
{
accs_f32[i].x = 0.0f;
accs_f32[i].y = 0.0f;
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// clear(acc_s);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// asm volatile("s_barrier \n\t");
......@@ -1278,37 +1221,192 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
// gK.data() = gK.data() + (offset_k);
#if 1
gK.data() = gK.data() + (offset_k);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8_tp1<false, false>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
auto gK_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const int k_zero_pad = std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0);
uint32x4_t gK_rscr = make_rscr((unsigned char*)(gK.data().get() + gK_offset), params.k_row_stride, k_zero_pad);
auto k_lds_addr = reinterpret_cast<size_t>(sK.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64);
if (n_block * kBlockN + ((warp_id) % 4) * 16 < seqlen_k || masking_step != 0)
{
k_lds_addr |= 0x80000000;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 0, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256+128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
if (warp_id < 4)
{
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 512, 1, 1, 0, 0);
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 0);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 1);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 2);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 3);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
gK.data() = gK.data() + ( - offset_k);
auto k_lds_read_ptr = sK.data().get() + (warp_id / 4) * 16 * 64;
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s);
constexpr static int k_read_lds_offset = 32 * 64;
{
constexpr static int k_idx = 0;
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 1;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
// if (block0())
// {
// printf(" %x %x %x %x %x %x %x %x \n", q_r[k_idx].fp8_array[0], q_r[k_idx].fp8_array[1], q_r[k_idx].fp8_array[2], q_r[k_idx].fp8_array[3], k_data.fp8_array[0], k_data.fp8_array[1], k_data.fp8_array[2], k_data.fp8_array[3]);
// }
}
#if 1
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2));
cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3));
cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s);
{
constexpr static int k_idx = 2;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 3;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4));
cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5));
cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s);
{
constexpr static int k_idx = 4;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 5;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6));
cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
{
constexpr static int k_idx = 6;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 7;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 8), tSrK_copy_view(_, _, 8));
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
gK.data() = gK.data() + (-offset_k);
// asm volatile("s_barrier \n\t");
{
constexpr static int k_idx = 8;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
#endif
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w;
acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
#endif
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
......@@ -1333,20 +1431,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
: softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/false, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
}
// asm volatile("s_barrier \n\t");
// if (block0() && tidx < 64)
// {
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// // );
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// acc_s(4), acc_s(5), acc_s(6), acc_s(7)
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// );
// }
#if 1
Fp8_storage p_fp8;
{
......@@ -1371,83 +1455,53 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
__syncthreads();
p_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16 + (warp_id % 4) * 16 * 64]));
__builtin_amdgcn_sched_barrier(0);
}
{
__builtin_amdgcn_sched_barrier(0);
for (int i = 0; i < 4; i++)
{
{
int k = 0;
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k))));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1))));
int lane_id = tidx % 64;
int row = lane_id / 4;
int col = lane_id % 4;
col = (col + (row / 2) % 4) % 4;
auto lds_offset = row * 64 + col * 16 + (warp_id / 4) * 64 * 64;
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// Fp8_storage v0_0, v0_1;
// v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(A_smem + lds_offset));
// v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64));
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
// acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
for (int j = 0; j < 4; j++)
{
Val tmp;
tmp.data[0] = v0_0.fp8_array[j];
tmp.data[1] = v0_1.fp8_array[j];
acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
}
}
for (int n = 0; n < 4; n++)
{
int k = 2;
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k))));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(&(tOsVt(0, i, k+1))));
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
for (int j = 0; j < 4; j++)
{
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[0], v, acco_f32[n * 4 + j], true, false);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
}
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128 + 32 * 64));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128 + 32 * 64));
for (int j = 0; j < 4; j++)
{
Val tmp;
tmp.data[0] = v0_0.fp8_array[j];
tmp.data[1] = v0_1.fp8_array[j];
acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
}
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[1], v, acco_f32[n * 4 + j], true, false);
}
}
__builtin_amdgcn_sched_barrier(0);
}
// if (block0())
// {
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// );
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// // acc_s(4), acc_s(5), acc_s(6), acc_s(7),
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// // );
// }
asm volatile("s_barrier \n\t");
#endif
}
#endif
using ElementO = typename Kernel_traits::ElementO;
using ElementAccum = typename Kernel_traits::ElementAccum;
const int split_offset = __ldg(params.num_splits_ptr + bidb);
......@@ -1458,9 +1512,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
if (NoSplit) {
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
......@@ -1482,12 +1533,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// if (tidx == 1)
// {
// printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w);
// }
{
using result_type = cutlass::Array<bfloat16_t, 2>;
int tidx = threadIdx.x;
......@@ -1547,49 +1592,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum(row, col + 3) = res1[1];
// col += 16;
}
// for (int j = 0; j < 4; j++)
// {
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0);
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// col += 16;
// }
}
// for (int n = 0; n < 8; n++)
// {
// using result_type = cutlass::Array<bfloat16_t, 2>;
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0);
// col = (tidx % 64 / 16) * 4 + n * 64;
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// }
}
}
}
} else {
constexpr bool Split = true;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
......@@ -1616,15 +1622,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
for (int m = 0; m < 1; m++) {
const int row = tidx % 16 + (warpid % 4) * 16;
if (row < params.seqlen_q - m_block * kBlockM) {
// for (int n = 0; n < 32; n++)
// {
// col = (tidx % 64 / 16) * 4 + n * 16;
// gOaccum(row, col) = acco_f32[n].x;
// gOaccum(row, col + 1) = acco_f32[n].y;
// gOaccum(row, col + 2) = acco_f32[n].z;
// gOaccum(row, col + 3) = acco_f32[n].w;
// }
for (int n = 0; n < 4; n++)
{
col = (tidx % 64 / 16) * 16 + n * 128 + (warp_id / 4) * 64;
......@@ -1658,44 +1655,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum(row, col + 2) = acco_f32[n * 4 + 2].w;
gOaccum(row, col + 3) = acco_f32[n * 4 + 3].w;
}
// for (int j = 0; j < 4; j++) {
// gOaccum(row, col) = acco_f32[n * 4 + j].x;
// gOaccum(row, col + 1) = acco_f32[n * 4 + j].y;
// gOaccum(row, col + 2) = acco_f32[n * 4 + j].z;
// gOaccum(row, col + 3) = acco_f32[n * 4 + j].w;
// col += 16;
// }
}
}
}
}
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// for (int n = 0; n < 8; n++)
// {
// acc_o(0, 0, n) = acco_f32[n * 2].x;
// acc_o(1, 0, n) = acco_f32[n * 2].y;
// acc_o(2, 0, n) = acco_f32[n * 2].z;
// acc_o(3, 0, n) = acco_f32[n * 2].w;
// acc_o(4, 0, n) = acco_f32[n * 2 + 1].x;
// acc_o(5, 0, n) = acco_f32[n * 2 + 1].y;
// acc_o(6, 0, n) = acco_f32[n * 2 + 1].z;
// acc_o(7, 0, n) = acco_f32[n * 2 + 1].w;
// }
// if (NoSplit)
// store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
// else
// store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4(const Flash_fwd_mla_params &params,
......
......@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
// Device-side declaration bound, via its asm label, to the symbol "llvm.exp2.f32".
// The `const` attribute tells the compiler the result depends only on the argument.
// NOTE(review): presumably this resolves to the LLVM base-2 exponential intrinsic — confirm with the toolchain.
extern __device__ __attribute__((const)) float __llvm_exp2_f32(float) __asm("llvm.exp2.f32");
// Packs a four-dword descriptor from a base pointer, a stride and a zero-pad count.
//
//   dword 0 : low  32 bits of `ptr`
//   dword 1 : high 32 bits of `ptr`
//   dword 2 : `stride`
//   dword 3 : bit 16 set, OR-ed with `zero_pad << 8`
//
// Callers in this file hand the result to __builtin_hcu_matrix_load_64x16_b8 as its
// first argument.
// NOTE(review): the hardware meaning of bit 16 and of the zero_pad field is not
// visible in this file — confirm against the ISA documentation before changing
// either. `zero_pad` is shifted without masking, so a value of 256 or more would
// spill into bit 16 and above; callers are assumed to keep it small — TODO confirm.
__device__ inline uint32x4_t make_rscr(unsigned char* ptr, const int stride, const int zero_pad) {
    const uint64_t base_addr = reinterpret_cast<uint64_t>(ptr);

    uint32x4_t descriptor;
    // Little-endian split of the 64-bit address across the first two dwords.
    descriptor[0] = static_cast<uint32_t>(base_addr);
    descriptor[1] = static_cast<uint32_t>(base_addr >> 32);
    descriptor[2] = stride;
    descriptor[3] = ((1 << 16) & 0XFFFFFFFF) | ((zero_pad) << 8);
    return descriptor;
}
// Issues one `buffer_load_dwordx4 ... lds` per wave, aimed at this wave's slot of the
// `dst` tile, using a per-lane offset of -1.
//
//   src    : tensor whose base pointer seeds the buffer descriptor.
//   dst    : tensor whose base pointer is the LDS destination base.
//   k_idx_ : index of the 64*128-byte slab inside `dst`; made wave-uniform with
//            readfirstlane.
//
// The instruction is only emitted when compiling for __gfx938__; on any other target
// the function computes the addresses and does nothing else.
//
// NOTE(review): presumably the -1 offset lies outside the descriptor's range so the
// load returns zeros, which would be how the LDS slot gets cleared (matching the
// function name) — confirm against the ISA documentation.
// NOTE(review): the asm statement overwrites m0 and stores to LDS but has an empty
// clobber list — verify that the compiler cannot keep a live value in m0 across it.
template <
    class SrcEngine, class SrcLayout,
    class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_zero_lds(
    Tensor<SrcEngine, SrcLayout> const& src,
    Tensor<DstEngine, DstLayout>      & dst,
    int k_idx_)
{
    constexpr int kLanesPerWave    = 64;
    constexpr int kBytesPerElement = 1;
    constexpr int kElementsPerLane = 16;
    // 64 lanes * 16 elements * 1 byte
    constexpr int kBytesPerWave    = kLanesPerWave * kElementsPerLane * kBytesPerElement;

    const int thread_in_block = threadIdx.x;
    // readfirstlane makes the values wave-uniform so they can be used as scalar operands.
    const int wave = __builtin_amdgcn_readfirstlane(thread_in_block / kLanesPerWave);
    const int slab = __builtin_amdgcn_readfirstlane(k_idx_);

    // Four-dword buffer descriptor: 64-bit base address of `src`, then two constant dwords.
    const uint64_t src_base = reinterpret_cast<uint64_t>(src.data().get());
    uint32x4_t buffer_desc = {0};
    buffer_desc[0] = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(src_base));
    buffer_desc[1] = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(src_base >> 32));
    buffer_desc[2] = 0x80000000;
    buffer_desc[3] = 0x00020000;

    // LDS destination for this wave: the base of `dst`, plus
    //   (wave % 4) * bytes-per-wave      slot among the four waves of a group,
    //   slab * 64*128                    slab selected by k_idx_,
    //   (wave / 4) * 64*64               offset of the wave group.
    const int wave_slot_bytes  = (wave % 4) * kBytesPerWave;
    const int slab_bytes       = (slab) * 64*128 * kBytesPerElement;
    const int wave_group_bytes = (wave / 4) * 64 * 64;
    int lds_dst_addr = reinterpret_cast<size_t>(dst.data().get()) + wave_slot_bytes + slab_bytes + wave_group_bytes;

    int lane_offset = -1;           // per-lane (VGPR) offset operand
    const int scalar_offset = 0;    // scalar (SGPR) offset operand

#if defined(__gfx938__)
    // m0 receives the LDS destination address, then the buffer load targets LDS directly.
    asm volatile(
        "s_mov_b32 m0, %1 \n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(lane_offset),
        "s"(lds_dst_addr), "s"(buffer_desc), "s"(scalar_offset)
        :);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment