Commit 6d3ed1da authored by zhanghj2's avatar zhanghj2
Browse files

优化nmz tp8

parent 2ff340aa
......@@ -563,7 +563,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const int tidx = threadIdx.x;
const int lane_idx = tidx % 64;
const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_k = 0;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));//64*576
......@@ -863,6 +863,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
constexpr static int STAGE = 8;
#if 1
uint8_t* kv_lds_write_ptr_base = reinterpret_cast<uint8_t*>(&tSsK(0, 0 ,0));
v4f c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1;
c0_0.x = 0.0f; c0_0.y = 0.0f; c0_0.z = 0.0f; c0_0.w = 0.0f;
c0_1.x = 0.0f; c0_1.y = 0.0f; c0_1.z = 0.0f; c0_1.w = 0.0f;
......@@ -879,6 +881,40 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
struct IsMaskBlock {};
struct IsFirstMaskBlock {};
struct IsNoMaskBlock {};
const auto gK_data = gK.data();
Fp8_storage kv_data[9];
{
int cur_block_table;
cur_block_table = block_table[n_block];
index_t offset_k;
offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
for (int i = 0; i < 8; i++)
{
buffer_load_copy_qkvfp8<false, true, false, false>(gK, kv_data[i].data, i, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
}
buffer_load_copy_qkvfp8<false, true, true, false>(gK, kv_data[8].data, 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
// __syncthreads();
uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base;
for (int i = 0; i < 8; i++) {
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[i].data;
kv_lds_write_ptr += 64*64;
}
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_qkvfp8<false, true, true, false>(gK, kv_data[8].data, 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
}
// if (block0())
// {
// printf("threadIdx.x %d kv_lds_write_ptr_base = %p\n ", threadIdx.x, kv_lds_write_ptr_base);
// }
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
......@@ -889,64 +925,30 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
clear(acc_s);
// asm volatile("s_barrier\n\t");
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
int cur_block_table;
const int *cur_block_table_ptr = block_table + block_idx;
// cur_block_table = block_table[block_idx - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k);
#if 1
lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - block_idx * kBlockN);
constexpr static int BUFFER_SIZE = 1;
uint128_t buffer[BUFFER_SIZE];
buffer_load_copy_qkvfp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - block_idx * kBlockN);
__syncthreads();
for (int i = 0; i < 7; i++) {
cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
}
// cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
// cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s);
asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2));
cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s);
asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3));
cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4));
cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5));
cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6));
cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
Fp8_storage v3_0, v3_1;
__ds_read_m32x32_row_col_rrow<3, 0, 3>(tOsVt, v3_0.data);
__ds_read_m32x32_row_col_rrow<3, 1, 3>(tOsVt, v3_1.data);
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
asm volatile("s_waitcnt vmcnt(0) \n\t");
buffer_to_tensor(buffer[0], tSrK, 8);
{
intx4_t* d = reinterpret_cast<intx4_t*>(&tSrK(0, 0, 8));
*d = kv_data[8].data;
}
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
#else
#endif
gK.data() = gK.data() + (-offset_k);
// if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
// }
if constexpr (!IS_NO_MASK_BLOCK) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
......@@ -962,27 +964,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
}
}
}
// We have key_padding_mask so we'll need to Check_inf
// if constexpr (n_masking_steps == 1)
// {
// softmax.template softmax_rescale_o_fp8</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
// }
// else
{
softmax.template softmax_rescale_o_fp8</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
}
// Tensor rP = flash::convert_type<Element>(acc_s);
softmax.template softmax_rescale_o_fp8</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
Fp8_storage data_fp8;
// convert_layout_acc_Aregs_fp8(tiled_mma, tiled_mma_o, rP, sP, data_fp8.data);
{
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
......@@ -996,6 +980,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
data_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16]));
}
if (block_idx > n_block_min) {
int cur_block_table;
const int *cur_block_table_ptr;
cur_block_table = block_table[block_idx - 1];
index_t offset_k;
// cur_block_table_ptr = block_table + block_idx;
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k);
for (int i = 0; i < 8; i++)
{
buffer_load_copy_qkvfp8<true, true, false, false>(gK, kv_data[i].data, i, params.k_row_stride, 0);
}
buffer_load_copy_qkvfp8<true, true, true, false>(gK, kv_data[8].data, 8, params.k_row_stride, 0);
}
{
Fp8_storage v0_0, v0_1;
Fp8_storage v1_0, v1_1;
......@@ -1034,6 +1038,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c2_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[1], v2_1.p[1], c2_1, true, false);
__builtin_amdgcn_sched_barrier(0);
// if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n ", c0_0.x, c0_0.y, c0_0.z, c0_0.w);
// }
}
if (block_idx > n_block_min) {
__syncthreads();
uint8_t* kv_lds_write_ptr = kv_lds_write_ptr_base;
for (int i = 0; i < 8; i++) {
*(reinterpret_cast<intx4_t*>(kv_lds_write_ptr)) = kv_data[i].data;
kv_lds_write_ptr += 64*64;
}
}
};
......
......@@ -570,7 +570,7 @@ CUTE_HOST_DEVICE
void
buffer_load_copy_qkvfp8(
Tensor<SrcEngine, SrcLayout> const& src,
uint128_t & dst,
intx4_t & dst,
int k_idx_, const int row_stride,
int offset_k,
const int max_MN=0)
......@@ -615,11 +615,41 @@ buffer_load_copy_qkvfp8(
);
}
else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
dst = *reinterpret_cast<uint128_t*>(&res);
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
}
} else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;//0-256
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;//0-63
constexpr int element_size = 1;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);//576
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
constexpr int elements_per_thread = 16;
uint32x4_t global_addr = {0};
global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
int mma_k = 32*64;
int row = lane % 16;
int col = lane / 16;
int row_offset = row + (warp_id * 16);
int col_offset = col * elements_per_thread + k_idx * 64;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
dst = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment