Commit 79096f6b authored by zhanghj2
Browse files

使用 __builtin_hcu_ds_read_m32x32_i8_alt2 指令 (Use the __builtin_hcu_ds_read_m32x32_i8_alt2 instruction)

parent 4b3bcb50
...@@ -472,23 +472,23 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params, ...@@ -472,23 +472,23 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params,
Tensor lse = softmax.template normalize_softmax_lse_fp8</*Is_dropout=*/false, Split>(tOrO, sRow_sum_reduce_buffer, scale_softmax, descale_k); Tensor lse = softmax.template normalize_softmax_lse_fp8</*Is_dropout=*/false, Split>(tOrO, sRow_sum_reduce_buffer, scale_softmax, descale_k);
using ElementO = std::conditional_t<!Split, Element, ElementAccum>; using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning // // Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t< // using SmemTiledCopyO = std::conditional_t<
!Split, // !Split,
typename Kernel_traits::SmemCopyAtomO, // typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum // typename Kernel_traits::SmemCopyAtomOaccum
>; // >;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); // auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); // auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor rO = flash::convert_type<ElementO>(tOrO); Tensor rO = flash::convert_type<ElementO>(tOrO);
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) // Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// __syncthreads(); // // __syncthreads();
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); // cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
...@@ -501,20 +501,20 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params, ...@@ -501,20 +501,20 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params,
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{}); Shape<Int<kBlockM>>{}, Stride<_1>{});
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>; // using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
GmemTiledCopyO gmem_tiled_copy_Oaccum; // GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); // auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) // Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); // Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads(); // __syncthreads();
// if (tidx >= kNThreadsS) { return; } // if (tidx >= kNThreadsS) { return; }
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum)); // Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); // cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
...@@ -528,15 +528,47 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params, ...@@ -528,15 +528,47 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params &params,
} }
} }
// Construct identity layout for sO // // Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts // // Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) // Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum))); // Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
// Clear_OOB_K must be false since we don't want to write zeros to gmem // // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM // gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
); // );
{
int tidx = threadIdx.x;
int col = 0;
for (int m = 0; m < size<1>(rO); m++)
{
const int row = get<0>(taccOcO(0, m, 0));
if (row < params.seqlen_q - m_block * kBlockM)
{
for (int n = 0; n < size<2>(rO); n++)
{
// col = (tidx % 64 / 16) * 4 + (tidx / 64) * 32 + n * 128;
// for (int ei = 0; ei < 8; ei += 4) {
// gOaccum(row, col) = rO(ei, m, n);
// gOaccum(row, col + 1) = rO(ei + 1, m, n);
// gOaccum(row, col + 2) = rO(ei + 2, m, n);
// gOaccum(row, col + 3) = rO(ei + 3, m, n);
// col += 16;
// }
col = (tidx % 64 / 16) * 8 + (tidx / 64) * 32 + n * 128;
gOaccum(row, col) = rO(0, m, n);
gOaccum(row, col + 1) = rO(1, m, n);
gOaccum(row, col + 2) = rO(2, m, n);
gOaccum(row, col + 3) = rO(3, m, n);
gOaccum(row, col + 4) = rO(4, m, n);
gOaccum(row, col + 5) = rO(5, m, n);
gOaccum(row, col + 6) = rO(6, m, n);
gOaccum(row, col + 7) = rO(7, m, n);
}
}
}
}
} }
...@@ -963,11 +995,15 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -963,11 +995,15 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
int32_t result; int32_t result;
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0), acc_s(1), result, false); result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0), acc_s(1), result, false);
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2), acc_s(3), result, true); result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2), acc_s(3), result, true);
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&sP[ (tid) * 16 + warp_id * 4]);
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0])); // int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0]));
*lds_ptr = result; *lds_ptr = result;
__syncthreads(); __syncthreads();
data_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16])); data_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16]));
// data_fp8.bf16[0] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4]));
// data_fp8.bf16[1] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4 + 64 * 4]));
// data_fp8.bf16[2] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4 + 2 * 64 * 4]));
// data_fp8.bf16[3] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4 + 3 * 64 * 4]));
} }
if (block_idx > n_block_min) { if (block_idx > n_block_min) {
...@@ -999,6 +1035,28 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -999,6 +1035,28 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
__ds_read_m32x32_row_col_rrow<0, 0, 0>(tOsVt, v0_0.data); __ds_read_m32x32_row_col_rrow<0, 0, 0>(tOsVt, v0_0.data);
__ds_read_m32x32_row_col_rrow<1, 0, 1>(tOsVt, v1_0.data); __ds_read_m32x32_row_col_rrow<1, 0, 1>(tOsVt, v1_0.data);
__ds_read_m32x32_row_col_rrow<2, 0, 2>(tOsVt, v2_0.data); __ds_read_m32x32_row_col_rrow<2, 0, 2>(tOsVt, v2_0.data);
// if (block0() && tidx < 64) {
// auto res0 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[0], false);
// auto res1 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[0], true);
// auto res2 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[1], false);
// auto res3 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[1], true);
// auto res4 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[2], false);
// auto res5 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[2], true);
// auto res6 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[3], false);
// auto res7 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[3], true);
// printf(" %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n", tidx, res0[0], res0[1],
// res1[0], res1[1],
// res2[0], res2[1],
// res3[0], res3[1],
// res4[0], res4[1],
// res5[0], res5[1],
// res6[0], res6[1],
// res7[0], res7[1]
// );
// }
c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[0], c3_0, true, false); c3_0 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[0], c3_0, true, false);
c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[1], c3_1, true, false); c3_1 = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(data_fp8.p[0], v3_0.p[1], c3_1, true, false);
...@@ -1062,14 +1120,34 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -1062,14 +1120,34 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
#endif #endif
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w; acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_1.x;
acc_o(4, 0, 0) = c0_1.x; acc_o(5, 0, 0) = c0_1.y; acc_o(6, 0, 0) = c0_1.z; acc_o(7, 0, 0) = c0_1.w; acc_o(2, 0, 0) = c0_0.y; acc_o(3, 0, 0) = c0_1.y;
acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_0.y; acc_o(2, 0, 1) = c1_0.z; acc_o(3, 0, 1) = c1_0.w; acc_o(4, 0, 0) = c0_0.z; acc_o(5, 0, 0) = c0_1.z;
acc_o(4, 0, 1) = c1_1.x; acc_o(5, 0, 1) = c1_1.y; acc_o(6, 0, 1) = c1_1.z; acc_o(7, 0, 1) = c1_1.w; acc_o(6, 0, 0) = c0_0.w; acc_o(7, 0, 0) = c0_1.w;
acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_0.y; acc_o(2, 0, 2) = c2_0.z; acc_o(3, 0, 2) = c2_0.w;
acc_o(4, 0, 2) = c2_1.x; acc_o(5, 0, 2) = c2_1.y; acc_o(6, 0, 2) = c2_1.z; acc_o(7, 0, 2) = c2_1.w; acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_1.x;
acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_0.y; acc_o(2, 0, 3) = c3_0.z; acc_o(3, 0, 3) = c3_0.w; acc_o(2, 0, 1) = c1_0.y; acc_o(3, 0, 1) = c1_1.y;
acc_o(4, 0, 3) = c3_1.x; acc_o(5, 0, 3) = c3_1.y; acc_o(6, 0, 3) = c3_1.z; acc_o(7, 0, 3) = c3_1.w; acc_o(4, 0, 1) = c1_0.z; acc_o(5, 0, 1) = c1_1.z;
acc_o(6, 0, 1) = c1_0.w; acc_o(7, 0, 1) = c1_1.w;
acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_1.x;
acc_o(2, 0, 2) = c2_0.y; acc_o(3, 0, 2) = c2_1.y;
acc_o(4, 0, 2) = c2_0.z; acc_o(5, 0, 2) = c2_1.z;
acc_o(6, 0, 2) = c2_0.w; acc_o(7, 0, 2) = c2_1.w;
acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_1.x;
acc_o(2, 0, 3) = c3_0.y; acc_o(3, 0, 3) = c3_1.y;
acc_o(4, 0, 3) = c3_0.z; acc_o(5, 0, 3) = c3_1.z;
acc_o(6, 0, 3) = c3_0.w; acc_o(7, 0, 3) = c3_1.w;
// acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w;
// acc_o(4, 0, 0) = c0_1.x; acc_o(5, 0, 0) = c0_1.y; acc_o(6, 0, 0) = c0_1.z; acc_o(7, 0, 0) = c0_1.w;
// acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_0.y; acc_o(2, 0, 1) = c1_0.z; acc_o(3, 0, 1) = c1_0.w;
// acc_o(4, 0, 1) = c1_1.x; acc_o(5, 0, 1) = c1_1.y; acc_o(6, 0, 1) = c1_1.z; acc_o(7, 0, 1) = c1_1.w;
// acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_0.y; acc_o(2, 0, 2) = c2_0.z; acc_o(3, 0, 2) = c2_0.w;
// acc_o(4, 0, 2) = c2_1.x; acc_o(5, 0, 2) = c2_1.y; acc_o(6, 0, 2) = c2_1.z; acc_o(7, 0, 2) = c2_1.w;
// acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_0.y; acc_o(2, 0, 3) = c3_0.z; acc_o(3, 0, 3) = c3_0.w;
// acc_o(4, 0, 3) = c3_1.x; acc_o(5, 0, 3) = c3_1.y; acc_o(6, 0, 3) = c3_1.z; acc_o(7, 0, 3) = c3_1.w;
if (NoSplit) if (NoSplit)
store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax); store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
......
...@@ -2709,13 +2709,20 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, Ten ...@@ -2709,13 +2709,20 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, Ten
template<int row, int col, int r_row, typename Tensor0> template<int row, int col, int r_row, typename Tensor0>
__forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, intx4_t& dst) __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, intx4_t& dst)
{ {
#if 0
auto lds = reinterpret_cast<int *>(src.data().get()); auto lds = reinterpret_cast<int *>(src.data().get());
auto layout = src.layout(); auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 1; constexpr short offset = layout(0, row, col) * 1;
auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset); auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset);
dst = d; dst = d;
#else
auto lds = reinterpret_cast<uint8_t *>(src.data().get());
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 1;
lds += offset;
dst = __builtin_hcu_ds_read_m32x32_i8_alt2((__attribute__((address_space(3))) int*)(lds));
#endif
} }
#endif #endif
/* /*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment