Commit c3cf875a authored by zhanghj2's avatar zhanghj2
Browse files

实现sparse prefill, 还有bug

parent 50e2de8d
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "params.h" #include "params.h"
// #include "sm90/prefill/sparse/phase1.h" #include "sm90/prefill/sparse/phase1.h"
enum class FwdFeatures : int { enum class FwdFeatures : int {
...@@ -39,7 +39,7 @@ protected: ...@@ -39,7 +39,7 @@ protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
// sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params); sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
}); });
}); });
} }
...@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -208,34 +208,12 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features.push_back(FwdFeatures::TOPK_LENGTH); required_features.push_back(FwdFeatures::TOPK_LENGTH);
} }
// if (is_sm90a) { if (is_sm90a) {
// Fwd_Sm90_Impl fwd_impl; Fwd_Sm90_Impl fwd_impl;
// fwd_impl.run(params, required_features); fwd_impl.run(params, required_features);
// } else if (is_sm100f) { } else {
// if (h_q == 64) { TORCH_CHECK(false, "Unsupported architecture");
// Fwd_Sm100_Head64_Impl fwd_impl; }
// fwd_impl.run(params, required_features);
// } else if (h_q == 128) {
// Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
// Fwd_Sm100_Head128_Impl regular_impl;
// bool use_small_topk_impl = false;
// if (
// (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
// !regular_impl.check_if_all_features_are_supported(required_features)
// ) {
// use_small_topk_impl = true;
// }
// if (use_small_topk_impl) {
// small_topk_impl.run(params, required_features);
// } else {
// regular_impl.run(params, required_features);
// }
// } else {
// TORCH_CHECK(false, "Unsupported h_q: ", h_q);
// }
// } else {
// TORCH_CHECK(false, "Unsupported architecture");
// }
return {out, max_logits, lse}; return {out, max_logits, lse};
} }
...@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK; ...@@ -19,61 +19,97 @@ static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK; static constexpr int D_K = D_QK;
static constexpr int D_V = 512; static constexpr int D_V = 512;
static constexpr int B_H = 64; static constexpr int kNWarps = 4;
static constexpr int B_H = 16;
static constexpr int B_TOPK = 64; // TopK block size static constexpr int B_TOPK = 64; // TopK block size
static constexpr int NUM_THREADS = 128*3; static constexpr int NUM_THREADS = kNWarps * 64;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
template<int NUM_TILES> using Element = cutlass::bfloat16_t;
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( using elem_type = Element;
GMMA::Layout_K_SW128_Atom<bf16>{}, using ElementAccum = float;
Shape<Int<B_H>, Int<64*NUM_TILES>>{}, using index_t = int64_t;
Step<_1, _2>{} static constexpr int kBlockM = B_H;
), Shape<_1, _1>{})); static constexpr int kBlockN = B_TOPK;
static constexpr int kHeadDim = D_QK;
template<int NUM_TILES> static constexpr int kHeadDimV = D_V;
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{}, using ValLayoutMNK = Layout<Shape<_1, _1, _1>>;
Shape<Int<B_H>, Int<64*NUM_TILES>>{}, // 没打开?
Step<_1, _2>{} // #if defined(__gfx936__) || defined(__gfx938__) || 1
), Shape<_1, _1>{})); // using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
template<int NUM_TILES> // MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( // MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{}, // >;
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{}, using MMA_Atom_Arch = std::conditional_t<
Step<_1, _2>{} std::is_same_v<elem_type, cutlass::half_t>,
), Shape<_1, _1>{})); MMA_Atom<GFX928_16x16x32_F32F16F16F32_NN>,
MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NN>
template<int NUM_TILES> >;
using SmemLayoutKTilesTransposed = decltype(composition( using TiledMma = TiledMMA<
SmemLayoutKTiles<NUM_TILES>{}, MMA_Atom_Arch,
Layout<Shape<Int<64*NUM_TILES>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{} Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
)); ValLayoutMNK>;
// #endif
using SmemLayoutQ = SmemLayoutQTiles<D_Q/64>;
using SmemLayoutO = SmemLayoutOTiles<D_V/64>; using MMA_Atom_Arch_16x32 = std::conditional_t<
using SmemLayoutK = SmemLayoutKTiles<D_Q/64>; std::is_same_v<elem_type, cutlass::half_t>,
using SmemLayoutV = SmemLayoutKTilesTransposed<D_V/64>; MMA_Atom<GFX928_16x32x16_F32F16F16F32_NT>,
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<D_V/64/2>; MMA_Atom<GFX928_16x32x16_F32BF16BF16F32_NT>
>;
using SmemLayoutS = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{}, using TiledMma_O = TiledMMA<
Shape<Int<B_H>, Int<B_TOPK>>{} MMA_Atom_Arch_16x32,
), Shape<_1, _1>{})); Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
ValLayoutMNK>;
using SmemLayoutAtomQ =
Layout<Shape<Int<16>, Int<32>>, Stride<Int<32>, _1>>;
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutAtomK = decltype(composition(
Swizzle<3, 3, 3>{},
Layout<Shape<Int<8>, Int<32>>, Stride<Int<32>, _1>>{}));
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<16 * 32>>{}));
using SmemLayoutAtomV = SmemLayoutAtomK;
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomP = Layout<Shape<Int<4*16*16>>, Stride<Int<1>>>;
using SmemLayoutP = decltype(tile_to_shape(
SmemLayoutAtomP{},
Shape<Int<4*16*16>>{}));
using SmemLayoutRow = Layout<Shape<_128>, Stride<_1>>;
using SmemLayoutK_place_holder = decltype(tile_to_shape(
SmemLayoutAtomK{},
Shape<Int<kBlockN>, Int<4 * 32>>{}));
struct SharedMemoryPlan { struct SharedMemoryPlan {
union { union {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q; struct {
array_aligned<bf16, cosize_v<SmemLayoutO>> o; cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // Double buffer
} q_o; };
array_aligned<bf16, cosize_v<SmemLayoutK>> k[2]; struct {
array_aligned<bf16, cosize_v<SmemLayoutS>> s[D_QK == 576 ? 1 : 2]; // For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers cute::array_aligned<Element, cute::cosize_v<SmemLayoutK_place_holder>> smem_place_holder; // Double buffer
cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
bool is_kv_valid[2][B_TOPK]; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_sum;
float2 sM[32]; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
float2 sL[64]; // For reduction across WG0/1 in epilogue };
float final_max_logits[64], final_lse[64]; struct {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
};
};
// transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; // transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
}; };
......
This diff is collapsed.
...@@ -376,7 +376,7 @@ struct Softmax { ...@@ -376,7 +376,7 @@ struct Softmax {
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi); float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __log2f(sum); lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
......
...@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val) ...@@ -362,4 +362,202 @@ is_positive_infinity(const float& f_val)
return fp32.as_bits == inf_tmp.as_bits; return fp32.as_bits == inf_tmp.as_bits;
} }
/**
 * Issues one `buffer_load_dwordx4 ... offen ... lds` instruction per thread to copy a
 * tile from `src` into `dst`, using hand-computed per-lane byte offsets.
 *
 * The whole body is compiled only when `__gfx936__` or `__gfx938__` is defined; on any
 * other target this function compiles to an empty body and copies nothing.
 *
 * Template flags:
 *   Is_even_MN - when false, lanes whose row is >= `max_MN` get a byte offset of -1.
 *   Is_even_K  - when false (and only in the Is_load_Q path), lanes whose column offset
 *                is >= 576 get a byte offset of -1.
 *   Is_load_Q  - selects the lane->(row, col) mapping: the "Q" mapping (16 rows x 4
 *                column groups per wavefront, no swizzle) or the default mapping
 *                (16 rows per wavefront with an XOR-swizzled column).
 *
 * Parameters:
 *   src        - source tensor; only its base pointer (`src.data().get()`) is used.
 *   dst        - destination tensor; only its base pointer (`dst.data().get()`) is used.
 *   k_idx_     - tile index along K; made wavefront-uniform via readfirstlane.
 *   row_stride - distance between consecutive source rows, in elements.
 *   max_MN     - row bound used when Is_even_MN is false.
 *
 * NOTE(review): an offset of -1 presumably makes the hardware treat the access as out of
 * range so that the lane loads zeros - confirm against the ISA documentation.
 * NOTE(review): silently doing nothing on non-gfx936/gfx938 targets may hide a
 * misconfigured build - confirm this is intended.
 */
template <
bool Is_even_MN=true,
bool Is_even_K=true,
bool Is_load_Q=false,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
lds_direct_copy(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int k_idx_, const int row_stride,
const int max_MN=0)
{
#if defined(__gfx936__) || defined(__gfx938__)
{
if constexpr (Is_load_Q) {
// Q path: each wavefront covers 16 rows x (4 column groups of 8 elements).
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane yields a wavefront-uniform value for the wavefront index.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
// Size of one element in bytes (2, i.e. a 16-bit element type).
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// Split the 64-bit source address into two 32-bit words for the 4-dword descriptor.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// Descriptor words 0/1 hold the base address; words 2/3 are fixed constants.
// NOTE(review): words 2/3 are presumably the record count and format/flag bits of a
// buffer resource descriptor - confirm the meaning of 0x80000000 / 0x00020000.
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// Each lane moves 8 elements (16 bytes, one dwordx4); one wavefront moves 1024 bytes.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
// Elements written to dst per k_idx step (16 * 128).
int mma_k = 16*128;
// Lane -> (row 0..15, column group 0..3).
int row = lane % 16;
int col = lane / 16;
int row_offset = row ;
// Column in elements: 4 groups per wavefront, 8 elements per group, 128 columns per k_idx.
int col_offset = (col + warp_id * 4) * elements_per_thread + k_idx * 128;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Out-of-bounds lanes get offset -1 (see the NOTE in the function header).
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// 576 is a hard-coded column bound for this path.
if (!Is_even_K && col_offset >= 576) offset_v = -1;
// Destination address for this wavefront: base + wavefront slot + k_idx tile.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// m0 receives the destination address, then the buffer load is issued with the `lds` modifier.
// NOTE(review): the asm writes m0 and memory but declares no clobbers - confirm the
// compiler cannot reorder dependent accesses or assume m0 is preserved.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
else {
// Default path: each wavefront covers 16 rows; the column is XOR-swizzled per lane.
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane yields a wavefront-uniform value for the wavefront index.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
int lane = tidx % warp_size;
// Size of one element in bytes (2, i.e. a 16-bit element type).
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
// Split the 64-bit source address into two 32-bit words for the 4-dword descriptor.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// Descriptor words 0/1 hold the base address; words 2/3 are fixed constants.
// NOTE(review): confirm the meaning of 0x80000000 / 0x00020000 against the ISA docs.
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// Each lane moves 8 elements (16 bytes, one dwordx4); one wavefront moves 1024 bytes.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
// Elements written to dst per k_idx step (32 * 64).
int mma_k = 32*64;
// int row = lane / 4;
// int col = lane % 4;
// int swizzle_col = ((row / 2) ^ (col )) * 4 + (col % 4);
// To be optimized: for the last 8 rows, the row numbers need to be swapped.
int virtual_row = lane / 8;
int virtual_col = lane % 8;
// XOR swizzle over an 8x8 view of the 64 lanes.
int swizzle_col = virtual_row ^ virtual_col;
int row = lane / 4;
// 8->9 9->8
// For rows >= 8 this flips the lowest bit (8<->9, 10<->11, ...); rows < 8 are unchanged.
row = (row >= 8 ) ^ row;
// row = row >= 8 ? (swizzle_col / 4) > 0 ? row + 1 : row - 1 : row;
// Only the low two bits of the swizzled column select the 8-element column group.
int col = swizzle_col % 4;
int row_offset = row + (warp_id * 16) ;
// Column in elements: 8 elements per group, 32 columns per k_idx.
int col_offset = col * elements_per_thread + k_idx * 32;
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// Out-of-bounds rows get offset -1; Is_even_K is not consulted on this path.
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// Destination address for this wavefront: base + wavefront slot + k_idx tile.
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
// m0 receives the destination address, then the buffer load is issued with the `lds` modifier.
// NOTE(review): the asm writes m0 and memory but declares no clobbers - confirm the
// compiler cannot reorder dependent accesses or assume m0 is preserved.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
}
#endif
}
/**
 * Issues one `buffer_load_dwordx4 ... idxen offen ... lds` instruction per thread to
 * copy 8 elements of a caller-selected source row from `src` into `dst`.
 *
 * Unlike the offset-only form, the row is passed as an index (`row_offset`) together
 * with a byte offset inside the row; the row stride in bytes is encoded into the
 * descriptor's second word.
 *
 * Template flags `Is_even_K`, `Is_even_MN` and `Use_cache_swizzle` are not referenced
 * by the current body (only by commented-out code).
 *
 * Parameters:
 *   src        - source tensor; only its base pointer (`src.data().get()`) is used.
 *   dst        - destination tensor; only its base pointer (`dst.data().get()`) is used.
 *   row_offset - source row index for this lane; -1 is replaced by `max_MN`.
 *   col        - column group for this lane (multiplied by 8 elements).
 *   k_idx_     - tile index along K; made wavefront-uniform via readfirstlane.
 *   row_stride - distance between consecutive source rows, in elements.
 *   max_MN     - written into descriptor word 2 and used as the index for invalid rows.
 *
 * NOTE(review): descriptor word 2 is presumably the record count, so an index equal to
 * `max_MN` would be out of range and load zeros - confirm against the ISA documentation.
 * NOTE(review): unlike lds_direct_copy, this body is not wrapped in a
 * `__gfx936__` / `__gfx938__` guard - confirm it is only compiled for those targets.
 */
template <bool Is_even_K=true,
bool Is_even_MN=true,
bool Use_cache_swizzle = true,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout
// class IdxEngine, class IdxLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_for_prefill_sparse_mla(
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst,
int row_offset,
int col,
int k_idx_, const int row_stride, int max_MN=0)
{
constexpr int warp_size = 64;
int tidx = threadIdx.x;
// readfirstlane yields a wavefront-uniform value for the wavefront index.
int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
// NOTE(review): `lane` is computed but never used in this function.
int lane = tidx % warp_size;
// Size of one element in bytes (2, i.e. a 16-bit element type).
constexpr int element_size = 2;
int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
const int offset_s = 0;
// global addr
// uint32x4_t global_addr = {0};
// *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
// global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
// global_addr[2] = 0xfffffffe;
// global_addr[3] = 0x00020000;
// Split the 64-bit source address into two 32-bit words for the 4-dword descriptor.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
// Row stride in bytes (row_stride * 2) is OR-ed into bits 16+ of the high address word.
// NOTE(review): the value is not masked; a stride of 16384 bytes or more would spill
// past the 14-bit field described below - confirm callers stay within range.
glob_ptr.latter |= ((row_stride * 2) << 16); // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
global_addr[2] = max_MN;
global_addr[3] = 0x00020000;
// Each lane moves 8 elements (16 bytes, one dwordx4); one wavefront moves 1024 bytes.
constexpr int elements_per_thread = 8;
constexpr int bytes_per_warp = warp_size * 8 * element_size;
// Elements written to dst per k_idx step (32 * 64).
int mma_k = 32*64;
// Column in elements: 8 elements per group, 32 columns per k_idx.
int col_offset = col * elements_per_thread + k_idx * 32;
// Byte offset within the row only; the row itself is selected through the index operand.
int offset_v = (col_offset) * element_size; // bytes
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_MN && (row_offset >= max_MN || row_offset < 0)) offset_v = -1;
// Destination slot wraps every 4 k_idx steps (k_idx % 4).
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 4) * mma_k * element_size;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// Two-component address operand: [0] = row index, [1] = byte offset within the row.
uint32x2_t index_offset = {0};
// A row of -1 marks an invalid entry and is redirected to index max_MN.
index_offset[0] = row_offset == -1 ? max_MN : row_offset;
index_offset[1] = offset_v;
// m0 receives the destination address, then the indexed buffer load is issued with the `lds` modifier.
// NOTE(review): the asm writes m0 and memory but declares no clobbers - confirm the
// compiler cannot reorder dependent accesses or assume m0 is preserved.
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:);
}
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment