Commit 3722ec71 authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz fp8 tp1

parent 34489f46
......@@ -1075,7 +1075,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));//64*576
const auto gK_data = gK.data();
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDimV>>{},
make_stride(params.k_row_stride, _1{}));//64*512
......@@ -1099,6 +1099,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);//16*32*32
union Fp8_storage
{
// uint32x4_t val;
intx4_t data;
intx2_t p[2];
int32_t fp8_array[4];
......@@ -1198,10 +1199,341 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acco_f32[i].w = 0.0f;
}
// NOTE(review): STAGE is not referenced anywhere in the visible part of this function — confirm it is still needed.
constexpr static int STAGE = 8;
// Dynamic shared memory of the kernel; kv_lds_write_ptr_base below points into it.
extern __shared__ char shared_memory[];
// Empty tag types; an instance is passed to process_one_block to pick a compile-time variant.
struct IsMaskBlock {};
struct IsFirstMaskBlock {};
struct IsNoMaskBlock {};
struct IsLastBlock {};
// Position of this thread inside its group of 64 threads, viewed as a 16 x 4 (row, col) grid.
int lane_id = tidx % 64;
int row = lane_id / 4;
int col = lane_id % 4;
// Rotate the column index by (row / 2) % 4.
// NOTE(review): presumably a swizzle to avoid LDS bank conflicts — confirm against the matching read intrinsics.
col = (col + (row / 2) % 4) % 4;
// Per-thread offset added to sV.data() in the P*V loop of process_one_block:
// 64 per row, 16 per (rotated) column, plus 64 * 64 per warp group (warp_id / 4).
const auto lds_offset = row * 64 + col * 16 + (warp_id / 4) * 64 * 64;
// Per-thread byte address in shared_memory where prefetched KV registers are stored.
// Same row / col / warp-group terms as lds_offset, plus this warp's 16-row slice ((warp_id % 4) * 16 * 64).
uint8_t* kv_lds_write_ptr_base = reinterpret_cast<uint8_t*>(shared_memory) +
((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64 + row * 64 + col * 16;
// Register staging for the five 16-byte fragments each thread prefetches for the next KV block.
Fp8_storage kv_data[5];
// Prologue: stage the K tile of the first processed block (n_block) into shared memory (sK)
// before the main loop runs. Each warp handles a 16-row slice ((warp_id % 4) * 16) of the block.
{
int cur_block_table;
// const int *cur_block_table_ptr;
// Translate the logical block index into the physical KV-cache block via the block table.
cur_block_table = block_table[n_block];
index_t offset_k;
//gK.data() = gK_data + (offset_k);
// cur_block_table_ptr = block_table + n_block;
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
// Re-base gK on the start of that physical block (gK_data is the untouched base pointer).
offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
// buffer_load_copy_fp8_tp1<false, true, 0>(gK, kv_data[0].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, true, 1>(gK, kv_data[1].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, true, 2>(gK, kv_data[2].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, true, 3>(gK, kv_data[3].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_fp8_tp1<false, false, 4>(gK, kv_data[4].data, params.k_row_stride, seqlen_k - n_block * kBlockN);
// uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base;
// // for (int i = 0; i < )
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[0].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[1].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[2].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[3].data;
// kv_lds_write_ptr += 64 * 128;
// *(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[4].data;
// kv_lds_write_ptr += 64 * 128;
// gK.data() = gK.data() + (offset_k);
// Per-warp start inside the block: 64 columns per warp group (warp_id / 4) and
// 16 rows per warp within the group (warp_id % 4).
auto gK_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
// Number of rows by which the end of this warp's 16-row slice lies beyond seqlen_k (0 when fully in range).
// It is not clamped to 16 here, but the descriptor is only used below when the slice start is
// < seqlen_k, in which case the value is at most 15.
const int k_zero_pad = std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0);
// Resource descriptor for the matrix loads below.
// NOTE(review): how make_rscr encodes the row stride and k_zero_pad is not visible here — confirm.
uint32x4_t gK_rscr = make_rscr((unsigned char*)(gK.data().get() + gK_offset), params.k_row_stride, k_zero_pad);
// Destination of this warp's slice inside sK: 64 * 64 per warp group, 16 * 64 per warp.
auto k_lds_addr = reinterpret_cast<size_t>(sK.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64);
// Slice starts inside the valid sequence: load it from global memory.
if (n_block * kBlockN + ((warp_id) % 4) * 16 < seqlen_k)
{
// NOTE(review): bit 31 of the LDS address is set before it is handed to the builtin;
// its meaning is defined by __builtin_hcu_matrix_load_64x16_b8 — confirm against the ISA docs.
k_lds_addr |= 0x80000000;
// Four loads at source offsets 0 / 128 / 256 / 384; the LDS destination advances by 64 * 128 each time.
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 0, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256+128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
// Fifth chunk (source offset 512): only warps 0-3 load it; the remaining warps call the
// zero-fill helper for chunk 4 instead.
// NOTE(review): presumably because columns 512 + 64 .. 512 + 127 lie beyond the 576-wide head dim — confirm.
if (warp_id < 4)
{
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 512, 1, 1, 0, 0);
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
}
else
{
// Slice lies entirely beyond seqlen_k: call the zero-fill helper for all five chunks
// (presumably writes zeros into sK, judging by its name — body not visible here).
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 0);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 1);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 2);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 3);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
}
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_LAST_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsLastBlock>;
v4f accs_f32[2];
for (int i = 0; i < 2; i++)
{
accs_f32[i].x = 0.0f;
accs_f32[i].y = 0.0f;
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
__syncthreads();
auto k_lds_read_ptr = sK.data().get() + (warp_id / 4) * 16 * 64;
constexpr static int k_read_lds_offset = 32 * 64;
{
constexpr static int k_idx = 0;
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 1;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 2;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 3;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 4;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 5;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 6;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 7;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 8;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w;
acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w;
if constexpr (!IS_NO_MASK_BLOCK) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - block_idx * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
}
softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
Fp8_storage p_fp8;
{
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
int32_t result;
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), acc_s(1, 0, 0), result, false);
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true);
// int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64]));
// *lds_ptr = result;
int32_t result1;
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false);
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true);
// lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8]));
// *lds_ptr = result1;
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64]));
*lds_ptr = result;
int32_t* lds_ptr1 = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64 + 8]));
*lds_ptr1 = result1;
__syncthreads();
p_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16 + (warp_id % 4) * 16 * 64]));
__builtin_amdgcn_sched_barrier(0);
}
if (block_idx > n_block_min) {
int cur_block_table;
const int *cur_block_table_ptr;
cur_block_table = block_table[block_idx - 1];
index_t offset_k;
// cur_block_table_ptr = block_table + block_idx;
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
buffer_load_copy_fp8_tp1<true, true, 0>(gK, kv_data[0].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 1>(gK, kv_data[1].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 2>(gK, kv_data[2].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, true, 3>(gK, kv_data[3].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
buffer_load_copy_fp8_tp1<true, false, 4>(gK, kv_data[4].data, params.k_row_stride, seqlen_k - block_idx * kBlockN);
}
for (int n = 0; n < 4; n++)
{
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128));
for (int j = 0; j < 4; j++)
{
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[0], v, acco_f32[n * 4 + j], true, false);
}
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128 + 32 * 64));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128 + 32 * 64));
for (int j = 0; j < 4; j++)
{
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[1], v, acco_f32[n * 4 + j], true, false);
}
}
if (block_idx > n_block_min) {
__syncthreads();
uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base;
// for (int i = 0; i < )
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[0].data;
kv_lds_write_ptr += 64 * 128;
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[1].data;
kv_lds_write_ptr += 64 * 128;
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[2].data;
kv_lds_write_ptr += 64 * 128;
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[3].data;
kv_lds_write_ptr += 64 * 128;
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[4].data;
}
// asm volatile("s_barrier \n\t");
};
#if 0
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
......@@ -1494,7 +1826,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
asm volatile("s_barrier \n\t");
};
};
#endif
#if 1
if constexpr (n_masking_steps == 1) {
process_one_block(n_block, IsFirstMaskBlock{});
......
......@@ -2748,6 +2748,59 @@ lds_direct_copy_qkvfp8_zero_lds(
#endif
}
/// Loads one 16-byte fragment of an FP8 tile from global memory into `dst` with a raw
/// buffer load.
///
/// Thread mapping (threads are grouped in warps of 64):
///   - row    = lane / 4 + (warp_id % 4) * 16
///   - column = (lane % 4) * 16 + k_idx * 128 + (warp_id / 4) * 64
/// The byte offset from `src.data()` is row * row_stride + column (one byte per element).
///
/// Template parameters:
///   Is_even_MN  when false, rows >= max_MN are redirected to offset -1.
///   Is_even_K   when false, columns >= 576 are redirected to offset -1.
///   k_idx       index of the 128-column chunk to load; must be given explicitly.
///
/// Parameters:
///   src         source tensor; only its base pointer (src.data().get()) is used.
///   dst         receives the four loaded dwords.
///   row_stride  distance between consecutive rows, in elements.
///   max_MN      number of valid rows; only read when Is_even_MN is false.
///
/// NOTE(review): an offset of -1 (0xFFFFFFFF) exceeds descriptor word 2 (0x80000000), so the
/// load is presumably treated as out of range and returns zeros — confirm against the ISA docs.
template <
    bool Is_even_MN=true,
    bool Is_even_K=true,
    int k_idx,
    class SrcEngine, class SrcLayout
>
CUTE_HOST_DEVICE
void
buffer_load_copy_fp8_tp1(
    Tensor<SrcEngine, SrcLayout> const& src,
    intx4_t & dst,
    const int row_stride,
    const int max_MN=0)
{
    constexpr int kWarpSize          = 64;
    constexpr int kElementSize       = 1;    // bytes per element
    constexpr int kElementsPerThread = 16;   // one dwordx4 load = 16 one-byte elements
    constexpr int kRowsPerWarp       = 16;
    constexpr int kColsPerChunk      = 128;  // columns covered by one k_idx step
    constexpr int kColsPerWarpGroup  = 64;   // column shift between warp groups (warp_id / 4)
    constexpr int kColumnLimit       = 576;  // columns at or beyond this are never loaded

    const int tidx = threadIdx.x;
    // Broadcast the warp index from the first active lane so it is uniform across the warp.
    const int warp_id = __builtin_amdgcn_readfirstlane(tidx / kWarpSize);
    const int lane = tidx % kWarpSize;

    // Split the 64-bit base address into two 32-bit words with integer arithmetic instead of
    // punning through a struct pointer (which violated strict aliasing).
    const uint64_t base_addr = reinterpret_cast<uint64_t>(src.data().get());
    const uint32_t base_lo = static_cast<uint32_t>(base_addr);
    const uint32_t base_hi = static_cast<uint32_t>(base_addr >> 32);
    // base_hi |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride

    // Buffer resource descriptor: words 0/1 hold the base address, words 2/3 are fixed constants.
    // NOTE(review): word 2 is presumably the record count and word 3 the format/flag bits —
    // confirm against the target ISA documentation.
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(base_lo);
    global_addr[1] = __builtin_amdgcn_readfirstlane(base_hi);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;

    const int row = lane / 4;
    const int col = lane % 4;
    const int row_offset = row + ((warp_id % 4) * kRowsPerWarp);
    const int col_offset = col * kElementsPerThread + k_idx * kColsPerChunk + (warp_id / 4) * kColsPerWarpGroup;

    int offset_v = (row_offset * row_stride + col_offset) * kElementSize; // bytes
    if constexpr (!Is_even_K) {
        if (col_offset >= kColumnLimit) offset_v = -1;
    }
    if constexpr (!Is_even_MN) {
        if (row_offset >= max_MN) offset_v = -1;
    }

    dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment