Commit f01246b6 authored by zhanghj2
Browse files

优化prefill sparse写出 (Optimize the prefill sparse output write-out)

parent c3a5b02a
......@@ -204,10 +204,10 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow<0, 0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 0>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 4, params.stride_kv_s_kv, params.s_kv);
......@@ -233,10 +233,10 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow<0, 0, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
......@@ -263,10 +263,10 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow<0, 0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 2>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
......@@ -332,7 +332,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
{
flash::__ds_read_m32x16_row_col_rrow<0, 0, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 3>(tOsVt, tOrVt_copy_view);
......@@ -344,7 +344,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
// __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o);
// __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view);
......@@ -355,9 +355,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
// __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o);
// for (int i = 0; i < size(tOrP); i++)
......@@ -449,7 +449,18 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
{
// store O and gLSE
auto rO = flash::convert_type<Element>(acc_o);
// auto rO = flash::convert_type<Element>(acc_o);
auto float2bf16 = [] (float s) -> uint16_t {
uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32 += 0x8000u;
#endif
return uint16_t(x32 >> 16);
};
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
......@@ -457,11 +468,54 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
row = mi * kBlockM + laneId % 16;
if (row < params.h_q) {
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16) + ni * 128 + warpId * 32 ;
for (int ei = 0; ei < size<0>(acc_o); ++ei) {
gO(row, col) = rO(ei, mi, ni);
col += 4;
col = (laneId / 16) * 2 + ni * 128 + warpId * 32 ;
using result_type = cutlass::Array<Element, 2>;
for (int ei = 0; ei < 4; ei++)
{
#if defined(__gfx938__)
auto d = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(ei, mi, ni), 0, acc_o(ei + 4, mi, ni), 0);
auto res = reinterpret_cast<result_type const &>(d);
#else
result_type res;
Element e0, e1;
e0.storage = float2bf16(acc_o(ei, mi, ni));
e1.storage = float2bf16(acc_o(ei + 4, mi, ni));
res[0] = e0;
res[1] = e1;
#endif
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*(result_type*)(&gO(row, col)) = res;
col += 8;
}
// gO(row, col) = rO(0, mi, ni);
// gO(row, col + 1) = rO(1, mi, ni);
// col += 8;
// gO(row, col) = rO(2, mi, ni);
// gO(row, col + 1) = rO(3, mi, ni);
// col += 8;
// gO(row, col) = rO(4, mi, ni);
// gO(row, col + 1) = rO(5, mi, ni);
// col += 8;
// gO(row, col) = rO(6, mi, ni);
// gO(row, col + 1) = rO(7, mi, ni);
// gO(row, col) = rO(0, mi, ni);
// gO(row, col + 1) = rO(4, mi, ni);
// col += 8;
// gO(row, col) = rO(1, mi, ni);
// gO(row, col + 1) = rO(5, mi, ni);
// col += 8;
// gO(row, col) = rO(2, mi, ni);
// gO(row, col + 1) = rO(6, mi, ni);
// col += 8;
// gO(row, col) = rO(3, mi, ni);
// gO(row, col + 1) = rO(7, mi, ni);
// for (int ei = 0; ei < size<0>(acc_o); ei += 2) {
// gO(row, col) = rO(ei, mi, ni);
// col += 4;
// }
}
gLSE[row] = lse(mi);
gMax_logits[row] = topk_length == 0 ? -INFINITY : softmax.row_max(mi) * params.sm_scale;
......
......@@ -211,6 +211,28 @@ __forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0& src, Ten
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
template<int row, int col, int r_row, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_rrow_alt(Tensor0& src, Tensor1& dst)
{
auto lds = reinterpret_cast<__fp16 *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 2;
auto d = __builtin_amdgcn_ds_read_m32x16f16_alt((__attribute__((address_space(3))) __fp16*)(lds), offset);
uint16_t * d_ptr = reinterpret_cast<uint16_t*>(&d);
uint16_t * dst_ptr = reinterpret_cast<uint16_t*>(&(dst(0, r_row, col)));
dst_ptr[0] = d_ptr[0];
dst_ptr[1] = d_ptr[1];
dst_ptr[2] = d_ptr[2];
dst_ptr[3] = d_ptr[3];
dst_ptr[4] = d_ptr[4];
dst_ptr[5] = d_ptr[5];
dst_ptr[6] = d_ptr[6];
dst_ptr[7] = d_ptr[7];
}
template<int row, int col, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1& dst)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment