"wrappers/python/vscode:/vscode.git/clone" did not exist on "82703dff7395dbc80d320af2d101d3ea530d2a25"
Commit 34489f46 authored by zhanghj2
Browse files

优化代码

parent 98b7c697
......@@ -1052,6 +1052,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage,const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
if (n_block_max <= n_block_min) {
return;
}
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
......@@ -1196,9 +1199,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
constexpr static int STAGE = 8;
#if 1
for (int masking_step = 0; n_block >= n_block_min; ++masking_step, --n_block) {
struct IsMaskBlock {};
struct IsFirstMaskBlock {};
struct IsNoMaskBlock {};
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
v4f accs_f32[2];
for (int i = 0; i < 2; i++)
{
......@@ -1207,30 +1214,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// clear(acc_s);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// asm volatile("s_barrier \n\t");
int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1];
const int *cur_block_table_ptr = block_table + block_idx;
// cur_block_table = block_table[block_idx - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
// gK.data() = gK.data() + (offset_k);
#if 1
gK.data() = gK.data() + (offset_k);
auto gK_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const int k_zero_pad = std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0);
// const int k_zero_pad = std::min(std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const int k_zero_pad = std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0);
uint32x4_t gK_rscr = make_rscr((unsigned char*)(gK.data().get() + gK_offset), params.k_row_stride, k_zero_pad);
auto k_lds_addr = reinterpret_cast<size_t>(sK.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64);
if (n_block * kBlockN + ((warp_id) % 4) * 16 < seqlen_k || masking_step != 0)
if (block_idx * kBlockN + ((warp_id) % 4) * 16 < seqlen_k || IS_NO_MASK_BLOCK)
{
k_lds_addr |= 0x80000000;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 0, 1, 1, 0, 0);
......@@ -1405,33 +1408,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w;
acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
#endif
// #endif
if constexpr (!IS_NO_MASK_BLOCK) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
if (int(get<1>(tScS(i))) >= int(seqlen_k - block_idx * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
// asm volatile("s_barrier \n\t");
{
const bool is_first_masking_step = masking_step == 0;
// is_first_masking_step
// ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2)
// : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2);
is_first_masking_step
? softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/true, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32)
: softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/false, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
}
#if 1
// asm volatile("s_barrier \n\t");
softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
// #if 1
Fp8_storage p_fp8;
{
__builtin_amdgcn_sched_barrier(0);
......@@ -1498,31 +1494,42 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
asm volatile("s_barrier \n\t");
#endif
};
#if 1
if constexpr (n_masking_steps == 1) {
process_one_block(n_block, IsFirstMaskBlock{});
n_block--;
} else {
int masking_step = 1;
process_one_block(n_block, IsFirstMaskBlock{});
n_block--;
for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) {
process_one_block(n_block, IsMaskBlock{});
}
}
for(; n_block >= n_block_min; --n_block) {
process_one_block(n_block, IsNoMaskBlock{});
}
#endif
using ElementO = typename Kernel_traits::ElementO;
using ElementAccum = typename Kernel_traits::ElementAccum;
const int split_offset = __ldg(params.num_splits_ptr + bidb);
// Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
if (NoSplit) {
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
if (NoSplit) {
constexpr bool Split = false;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_fp8_tp1</*Is_dropout=*/false, Split, true>(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
......@@ -1598,13 +1605,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
} else {
constexpr bool Split = true;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
const int split_offset = __ldg(params.num_splits_ptr + bidb);
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor lse = softmax.template normalize_softmax_lse_fp8_tp1</*Is_dropout=*/false, Split, true>(acco_f32, sRow_sum_reduce_buffer, scale_softmax, descale_k);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + (row_offset_lseaccum)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment