Commit c3cf875a authored by zhanghj2's avatar zhanghj2
Browse files

实现sparse prefill, 还有bug

parent 50e2de8d
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "params.h" #include "params.h"
// #include "sm90/prefill/sparse/phase1.h" #include "sm90/prefill/sparse/phase1.h"
enum class FwdFeatures : int { enum class FwdFeatures : int {
...@@ -39,7 +39,7 @@ protected: ...@@ -39,7 +39,7 @@ protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
// sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params); sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
}); });
}); });
} }
...@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features.push_back(FwdFeatures::TOPK_LENGTH); required_features.push_back(FwdFeatures::TOPK_LENGTH);
} }
// if (is_sm90a) { if (is_sm90a) {
// Fwd_Sm90_Impl fwd_impl; Fwd_Sm90_Impl fwd_impl;
// fwd_impl.run(params, required_features); fwd_impl.run(params, required_features);
// } else if (is_sm100f) { } else {
// if (h_q == 64) { TORCH_CHECK(false, "Unsupported architecture");
// Fwd_Sm100_Head64_Impl fwd_impl; }
// fwd_impl.run(params, required_features);
// } else if (h_q == 128) {
// Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
// Fwd_Sm100_Head128_Impl regular_impl;
// bool use_small_topk_impl = false;
// if (
// (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
// !regular_impl.check_if_all_features_are_supported(required_features)
// ) {
// use_small_topk_impl = true;
// }
// if (use_small_topk_impl) {
// small_topk_impl.run(params, required_features);
// } else {
// regular_impl.run(params, required_features);
// }
// } else {
// TORCH_CHECK(false, "Unsupported h_q: ", h_q);
// }
// } else {
// TORCH_CHECK(false, "Unsupported architecture");
// }
return {out, max_logits, lse}; return {out, max_logits, lse};
} }
...@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK; ...@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK; static constexpr int D_K = D_QK;
static constexpr int D_V = 512; static constexpr int D_V = 512;
static constexpr int B_H = 64; static constexpr int kNWarps = 4;
static constexpr int B_H = 16;
static constexpr int B_TOPK = 64; // TopK block size static constexpr int B_TOPK = 64; // TopK block size
static constexpr int NUM_THREADS = 128*3; static constexpr int NUM_THREADS = kNWarps * 64;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
template<int NUM_TILES> using Element = cutlass::bfloat16_t;
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( using elem_type = Element;
GMMA::Layout_K_SW128_Atom<bf16>{}, using ElementAccum = float;
Shape<Int<B_H>, Int<64*NUM_TILES>>{}, using index_t = int64_t;
Step<_1, _2>{} static constexpr int kBlockM = B_H;
), Shape<_1, _1>{})); static constexpr int kBlockN = B_TOPK;
static constexpr int kHeadDim = D_QK;
template<int NUM_TILES> static constexpr int kHeadDimV = D_V;
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{}, using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
Shape<Int<B_H>, Int<64*NUM_TILES>>{}, // 没打开?
Step<_1, _2>{} // #if defined(__gfx936__) || defined(__gfx938__) || 1
), Shape<_1, _1>{})); // using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
template<int NUM_TILES> // MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( // MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{}, // >;
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{}, using MMA_Atom_Arch = std::conditional_t<
Step<_1, _2>{} std::is_same_v<elem_type, cutlass::half_t>,
), Shape<_1, _1>{})); MMA_Atom<GFX928_16x16x32_F32F16F16F32_NN>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NN>
template<int NUM_TILES> >;
using SmemLayoutKTilesTransposed = decltype(composition( using TiledMma = TiledMMA<
SmemLayoutKTiles<NUM_TILES>{}, MMA_Atom_Arch,
Layout<Shape<Int<64*NUM_TILES>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{} Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
)); ValLayoutMNK>;
// #endif
using SmemLayoutQ = SmemLayoutQTiles<D_Q/64>;
using SmemLayoutO = SmemLayoutOTiles<D_V/64>; using MMA_Atom_Arch_16x32 = std::conditional_t<
using SmemLayoutK = SmemLayoutKTiles<D_Q/64>; std::is_same_v<elem_type, cutlass::half_t>,
using SmemLayoutV = SmemLayoutKTilesTransposed<D_V/64>; MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<D_V/64/2>; MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using SmemLayoutS = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{}, using TiledMma_O = TiledMMA<
Shape<Int<B_H>, Int<B_TOPK>>{} MMA_Atom_Arch_16x32,
), Shape<_1, _1>{})); Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using SmemLayoutAtomQ =
Layout<Shape<Int<16>, Int<32>>, Stride<Int<32>, _1>>;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<16 * 32>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using SmemLayoutK_place_holder = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<4 * 32>>{}));
struct SharedMemoryPlan { struct SharedMemoryPlan {
union { union {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q; struct {
array_aligned<bf16, cosize_v<SmemLayoutO>> o; cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
} q_o; };
array_aligned<bf16, cosize_v<SmemLayoutK>> k[2]; struct {
array_aligned<bf16, cosize_v<SmemLayoutS>> s[D_QK == 576 ? 1 : 2]; // For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers cute::array_aligned<Element, cute::cosize_v<SmemLayoutK_place_holder>> smem_place_holder; // Double buffer
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
bool is_kv_valid[2][B_TOPK]; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
float2 sM[32]; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
float2 sL[64]; // For reduction across WG0/1 in epilogue };
float final_max_logits[64], final_lse[64]; struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
};
// transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; // transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
}; };
......
#pragma once #pragma once
#include "config.h" #include "config.h"
#include "utils.h" #include "utils.h"
#include "softmax.h"
#include "../../helpers.h" #include "../../helpers.h"
namespace sm90::fwd { namespace sm90::fwd {
...@@ -11,6 +11,433 @@ using namespace cute; ...@@ -11,6 +11,433 @@ using namespace cute;
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
__device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params) { __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params) {
extern __shared__ char smem_[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(smem_);
const int tidx = threadIdx.x;
static constexpr int kBlockM = B_H;
static constexpr int kBlockN = B_TOPK;
static constexpr int kHeadDim = D_QK;
static constexpr int kHeadDimV = D_V;
const int warp_idx = tidx / 64;
const int s_q_idx = blockIdx.x;
const int bidh = blockIdx.y;
const int lane_idx = tidx % 64;
const index_t row_offset_q = s_q_idx * params.stride_q_s_q + bidh * kBlockM * params.stride_q_h_q;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.stride_q_h_q, _1{}));
const index_t row_offset_k = 0 * params.stride_kv_h_kv;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.kv) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.stride_kv_s_kv, _1{}));
const index_t row_offset_topk = s_q_idx * params.stride_indices_s_q;
int* gIndices = reinterpret_cast<int *>(params.indices) + row_offset_topk;
Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), SmemLayoutQ{});
Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), SmemLayoutV{});
Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), SmemLayoutK{});
Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), SmemLayoutP{});
Tensor sVt = make_tensor(sV.data(), SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), SmemLayoutVtransposedNoSwizzle{});
Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), SmemLayoutRow{});
Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), SmemLayoutRow{});
TiledMMA tiled_mma = TiledMma{};
auto thr_mma = tiled_mma.get_thread_slice(tidx);
TiledMMA tiled_mma_o = TiledMma_O{};
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 0, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 1, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 2, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, true, true>(gQ, sQ, 3, params.stride_q_h_q, params.h_q - bidh * kBlockM);
flash::lds_direct_copy<false, false, true>(gQ, sQ, 4, params.stride_q_h_q, params.h_q - bidh * kBlockM);
auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11));
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15));
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17));
__syncthreads();
const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_topk_blocks = HAVE_TOPK_LENGTH ? ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK);
auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
Tensor tSrK = thr_mma.partition_fragment_B(sK);
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
Tensor tSrK_smem = thr_mma.partition_fragment_B(gK);
auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom<GFX928_DS_READ_DS_M32x16_B16, Element>{}, tiled_mma_o);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle);
Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
clear(acc_o);
flash::Softmax<size<1>(acc_o)> softmax;
auto calc_row_and_col = [&](const int block_idx) -> std::tuple<int, int> {
// 计算swizzle后的全局显存访存地址
int virtual_row = lane_idx / 8;
int virtual_col = lane_idx % 8;
int swizzle_col = virtual_row ^ virtual_col;
int row = lane_idx / 4;
row = (row >= 8 ) ^ row;
int col = swizzle_col % 4;
int warp_id = tidx / 64;
int row_offset = block_idx * kBlockN + row + (warp_idx * 16) ;
// row_offset = row_offset < params.topk ? gIndices[row_offset] : -1;
row_offset = gIndices[row_offset];
return {row_offset, col};
};
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx++)
{
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s);
auto [row_offset, col] = calc_row_and_col(block_idx);
if constexpr (D_QK == 576)
{
for (int i = 16; i < 18; i++)
{
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, i, params.stride_kv_s_kv, params.s_kv);
}
asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0 + 16), tSrK(_, _, 0), acc_s);
asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 0, params.stride_kv_s_kv, params.s_kv);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1 + 16), tSrK(_, _, 1), acc_s);
}
else
{
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 0, params.stride_kv_s_kv, params.s_kv);
}
for (int i = 1; i < 4; i++) {
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, i, params.stride_kv_s_kv, params.s_kv);
}
int k_idx = 0;
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow<0, 0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 0>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 4, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 5, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 6, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 7, params.stride_kv_s_kv, params.s_kv);
k_idx++;
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow<0, 0, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 8, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 9, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 10, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 11, params.stride_kv_s_kv, params.s_kv);
k_idx++;
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow<0, 0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 2>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 12, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 13, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 14, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 15, params.stride_kv_s_kv, params.s_kv);
k_idx++;
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++;
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm volatile("s_barrier\n\t");
// if (block0())
// {
// printf(" %.2f %.2f %.2f \n ", acc_s(0), acc_s(1), acc_s(2));
// }
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
auto is_valid_token = [&](const int idx) -> bool {
int offs = int(get<1>(tScS(idx))) + block_idx * kBlockN;
int t = gIndices[offs];
bool is_cur_token_valid = t >= 0 && t < params.s_kv;
if constexpr (HAVE_TOPK_LENGTH) {
is_cur_token_valid = is_cur_token_valid && (offs < topk_length);
}
return is_cur_token_valid;
};
{
for (int i = 0; i < size(acc_s); ++i) {
// idx = idx < params.topk ? gIndices[idx] : -1;
if (!is_valid_token(i)) acc_s(i) = -INFINITY;
}
}
block_idx == 0 ?
softmax.template softmax_rescale_o_prefill</*Is_first=*/true, /*Check_inf=*//*Is_local=*/false>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2):
softmax.template softmax_rescale_o_prefill</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2);
// if (block0())
// {
// printf(" %.2f %.2f %.2f %.2f %.2f %.2f \n ", acc_s(0), acc_s(1), acc_s(2), acc_s(3), softmax.row_max(0), params.sm_scale_div_log2);
// }
Tensor rP = flash::convert_type<Element>(acc_s);
Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
{
flash::__ds_read_m32x16_row_col_rrow<0, 0, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o);
flash::__ds_read_m32x16_row_col_rrow<0, 1, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o);
// __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow<0, 2, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o);
flash::__ds_read_m32x16_row_col_rrow<0, 3, 3>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o);
// for (int i = 0; i < size(tOrP); i++)
// {
// tOrP(i) = Element(1.0f);
// }
// cute::copy(smem_tiled_copy_V, tOsVt(_, 0, 0), tOrVt_copy_view(_, 0, 0));
// for (int i = 0; i < 4; i++) {
// cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
// // if (tOrVt(_, _, i) )
// cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
// }
// for (int i = 0; i < 8 * 2 * 16; i++)
// {
// }
// asm volatile("s_barrier");
// if (thread0()) {
// for (int i = 0; i < 64; i++) {
// for (int j = 0; j < 512; j++) {
// printf(" %.2f ", float(sK(i, j)));
// }
// printf("\n");
// }
// }
// if (block0())
// {
// print("tidx %d acc_s %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n",
// tidx, acc_o(0), acc_o(1), acc_o(2), acc_o(3),
// acc_o(4), acc_o(5), acc_o(6), acc_o(7),
// acc_o(8), acc_o(9), acc_o(10), acc_o(11),
// acc_o(12), acc_o(13), acc_o(14), acc_o(15)
// );
// }
}
// asm volatile("s_barrier\n\t");
}
Tensor lse = softmax.template normalize_softmax_lse_prefill<false>(acc_o, sRow_sum_reduce_buffer, params.sm_scale);
const index_t row_offset_o = s_q_idx * params.h_q * params.d_v + bidh * kBlockM * params.d_v;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.out) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.d_v, _1{}));
// lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
const index_t row_offset_lse = s_q_idx * params.h_q + bidh * kBlockM;
float* gLSE = reinterpret_cast<float *>(params.lse) + row_offset_lse;
// const index_t row_offset_lse = m_block * params.h_q;
float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse;
if (params.attn_sink != nullptr) {
float rAttn_sink = __ldg((float*)params.attn_sink + start_head_idx + lane_idx % 16);
if (flash::is_positive_infinity(rAttn_sink))
{
for (int i = 0; i < size(acc_o); i++)
{
acc_o(i) = 0.0f;
}
}
else
{
if (!flash::is_positive_infinity(lse(0)))
{
float lse_exp2 = __builtin_amdgcn_exp2f(lse[0] * CUDART_L2E_F);
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
float o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
for (int i = 0; i < size(acc_o); i++)
{
acc_o(i) *= o_scale;
}
}
}
}
// if (block0())
// {
// print("tidx %d acc_s %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n",
// tidx, acc_o(0), acc_o(1), acc_o(2), acc_o(3),
// acc_o(4), acc_o(5), acc_o(6), acc_o(7),
// acc_o(8), acc_o(9), acc_o(10), acc_o(11),
// acc_o(12), acc_o(13), acc_o(14), acc_o(15)
// );
// }
{
// store O and gLSE
auto rO = flash::convert_type<Element>(acc_o);
int row, col;
const int warpId = tidx / 64;
const int laneId = tidx % 64;
for (int mi = 0; mi < size<1>(acc_o); ++mi) {
row = mi * kBlockM + laneId % 16;
if (row < params.h_q) {
for (int ni = 0; ni < size<2>(acc_o); ++ni) {
col = (laneId / 16) + ni * 128 + warpId * 32 ;
for (int ei = 0; ei < size<0>(acc_o); ++ei) {
gO(row, col) = rO(ei, mi, ni);
col += 4;
}
}
gLSE[row] = lse(mi);
gMax_logits[row] = softmax.row_max(mi) * params.sm_scale;
}
}
}
} }
...@@ -26,9 +453,11 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para ...@@ -26,9 +453,11 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
KU_ASSERT(params.topk > 0); KU_ASSERT(params.topk > 0);
KU_ASSERT(params.h_q % B_H == 0); KU_ASSERT(params.h_q % B_H == 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>;
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); constexpr size_t smem_size = 65536;
dim3 grid((params.s_q, params.h_q/B_H), 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH();
} }
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
......
...@@ -376,7 +376,7 @@ struct Softmax { ...@@ -376,7 +376,7 @@ struct Softmax {
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi); float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __log2f(sum); lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
......
...@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val) ...@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val)
return fp32.as_bits == inf_tmp.as_bits; return fp32.as_bits == inf_tmp.as_bits;
} }
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
#if defined(__gfx936__) || defined(__gfx938__)
{
if constexpr (Is_load_Q) {
// // 32x64
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
int mma_k = 16*128;
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
else {
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
int mma_k = 32*64;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// 此处待优化,后8行,行号需要交换
int virtual_row = lane / 8;
int virtual_col = lane % 8;
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;
// 8->9 9->8
row = (row >= 8 ) ^ row;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
int col = swizzle_col % 4;
int row_offset = row + (warp_id * 16) ;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
}
#endif
}
template <bool Is_even_K=true,
bool Is_even_MN=true,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout
// class IdxEngine, class IdxLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_for_prefill_sparse_mla(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int row_offset,
int col,
int k_idx_, const int row_stride, int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
glob_ptr.latter |= ((row_stride * 2) << 16); // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = max_MN;
global_addr[3] = 0x00020000;
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
int mma_k = 32*64;
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (col_offset) * element_size; // bytes
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_MN && (row_offset >= max_MN || row_offset < 0)) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 4) * mma_k * element_size;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
uint32x2_t index_offset = {0};
index_offset[0] = row_offset == -1 ? max_MN : row_offset;
index_offset[1] = offset_v;
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment