Commit 2ff340aa authored by zhanghj2's avatar zhanghj2
Browse files

优化tp8 nmz 代码

parent 4d897ed1
...@@ -547,6 +547,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -547,6 +547,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const int n_split_idx, const int seqlen_k, const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit, const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) { SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
// Empty block range (n_block_max <= n_block_min): there is no block to process, so leave immediately.
// NOTE(review): this returns before any later code in the function runs. Confirm that callers do not
// rely on outputs being written for an empty range, and that the condition is uniform across the
// threads that would otherwise meet at the s_barrier instructions further down.
if (n_block_max <= n_block_min) {
return;
}
constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kHeadDim = Kernel_traits::kHeadDim;
...@@ -872,16 +875,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -872,16 +875,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f; c3_0.x = 0.0f; c3_0.y = 0.0f; c3_0.z = 0.0f; c3_0.w = 0.0f;
c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f; c3_1.x = 0.0f; c3_1.y = 0.0f; c3_1.z = 0.0f; c3_1.w = 0.0f;
// #pragma unroll extern __shared__ char shared_memory[];
for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) { struct IsMaskBlock {};
struct IsFirstMaskBlock {};
struct IsNoMaskBlock {};
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
clear(acc_s); clear(acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + block_idx;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[block_idx - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
...@@ -889,17 +899,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -889,17 +899,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
#if 1 #if 1
lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - block_idx * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - block_idx * kBlockN);
constexpr static int BUFFER_SIZE = 1; constexpr static int BUFFER_SIZE = 1;
uint128_t buffer[BUFFER_SIZE]; uint128_t buffer[BUFFER_SIZE];
buffer_load_copy_qkvfp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy_qkvfp8<false, true, true, true>(gK, buffer[0], 8, params.k_row_stride, offset_k, seqlen_k - block_idx * kBlockN);
asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
...@@ -937,21 +947,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -937,21 +947,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
#else #else
#endif #endif
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
if constexpr (!IS_NO_MASK_BLOCK) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS); Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) { for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) { if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; if (int(get<1>(tScS(i))) >= int(seqlen_k - block_idx * kBlockN)) acc_s(i) = -INFINITY;
} else { } else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i))); int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
} }
} }
// We have key_padding_mask so we'll need to Check_inf // We have key_padding_mask so we'll need to Check_inf
// if constexpr (n_masking_steps == 1) // if constexpr (n_masking_steps == 1)
// { // {
...@@ -959,10 +971,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -959,10 +971,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
// } // }
// else // else
{ {
const bool is_first_masking_step = masking_step == 0;
is_first_masking_step softmax.template softmax_rescale_o_fp8</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
? softmax.template softmax_rescale_o_fp8</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1)
: softmax.template softmax_rescale_o_fp8</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
} }
...@@ -1025,7 +1035,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -1025,7 +1035,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
};
// Drive process_one_block over the blocks from n_block down to n_block_min, choosing the tag type
// per phase so that masking and the Is_first softmax path are selected at compile time.
// NOTE(review): the first call in either branch is not guarded by n_block >= n_block_min;
// presumably the early return on an empty block range makes that safe — confirm against the
// initial value of n_block, which is set outside this view.
if constexpr (n_masking_steps == 1) {
// Single masking step: only the first block takes the masked, Is_first path.
process_one_block(n_block, IsFirstMaskBlock{});
n_block--;
} else {
// masking_step starts at 1 because the first masked block is handled right here.
int masking_step = 1;
process_one_block(n_block, IsFirstMaskBlock{});
n_block--;
// Remaining masked blocks: masking still applied, but no longer the first softmax step.
for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) {
process_one_block(n_block, IsMaskBlock{});
}
}
// All remaining blocks are processed with the masking code compiled out.
for(; n_block >= n_block_min; --n_block) {
process_one_block(n_block, IsNoMaskBlock{});
} }
#endif #endif
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w; acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment