Commit 0651671f authored by zhanghj2
Browse files

sparse decode支持head16 (sparse decode: add support for head16)

parent b94fdd0f
...@@ -57,6 +57,9 @@ inline int int64_stride_to_int(int64_t orig_stride) { ...@@ -57,6 +57,9 @@ inline int int64_stride_to_int(int64_t orig_stride) {
} else if (NUM_HEADS == 64) { \ } else if (NUM_HEADS == 64) { \
static constexpr int CONSTEXPR_NAME = 64; \ static constexpr int CONSTEXPR_NAME = 64; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (NUM_HEADS <= 16) { \
static constexpr int CONSTEXPR_NAME = 16; \
return __VA_ARGS__(); \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \ TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \
} \ } \
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
// Feature set of sparse decoding kernels // Feature set of sparse decoding kernels
enum class DecodeFeatures : int { enum class DecodeFeatures : int {
HEAD_16,
HEAD_64, HEAD_64,
HEAD_128, HEAD_128,
...@@ -41,6 +42,7 @@ public: ...@@ -41,6 +42,7 @@ public:
class Decode_Sm90_Impl : public DecodeImplBase { class Decode_Sm90_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES( DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_16,
DecodeFeatures::HEAD_64, DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_128, DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512, DecodeFeatures::HEAD_DIM_512,
...@@ -56,6 +58,13 @@ class Decode_Sm90_Impl : public DecodeImplBase { ...@@ -56,6 +58,13 @@ class Decode_Sm90_Impl : public DecodeImplBase {
public: public:
DecodeImplMeta get_meta(int h_q, int s_q) override { DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch(); Arch arch = Arch();
// Small-head path: taken for every h_q <= 16.
// The first field was previously divided by (h_q / 16). That integer quotient
// is 1 when h_q == 16 but truncates to 0 for any h_q < 16, which made the
// expression an integer division by zero. Within this branch the only
// non-zero value the divisor can take is 1, so it is dropped: the result for
// h_q == 16 is unchanged and smaller head counts no longer fault.
// NOTE(review): s_q is still used as a divisor and is assumed to be >= 1 —
// confirm against the caller.
if (h_q <= 16) {
    return {
        // Never report fewer than 1, even when num_sms * 2 / s_q rounds down to 0.
        std::max(arch.num_sms * 2 / s_q, 1),
        5,
        64
    };
}
return { return {
std::max(arch.num_sms / s_q / (h_q/64), 1), std::max(arch.num_sms / s_q / (h_q/64), 1),
5, 5,
...@@ -218,7 +227,9 @@ sparse_attn_decode_interface( ...@@ -218,7 +227,9 @@ sparse_attn_decode_interface(
} }
std::vector<DecodeFeatures> features; std::vector<DecodeFeatures> features;
if (h_q == 64) { if (h_q <= 16) {
features.push_back(DecodeFeatures::HEAD_16);
} else if (h_q == 64) {
features.push_back(DecodeFeatures::HEAD_64); features.push_back(DecodeFeatures::HEAD_64);
} else if (h_q == 128) { } else if (h_q == 128) {
features.push_back(DecodeFeatures::HEAD_128); features.push_back(DecodeFeatures::HEAD_128);
......
...@@ -100,58 +100,9 @@ struct SharedMemoryPlan { ...@@ -100,58 +100,9 @@ struct SharedMemoryPlan {
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max; cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
}; };
// struct {
// cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
// // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
// };
// struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
// };
}; };
// array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// union {
// array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
// array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
// array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
// } u;
// CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
// bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
// float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
// transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
}; };
// template<
// typename Shape_Q, typename TMA_Q
// >
// using TiledMMA_QK = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_QK_rQ = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
static __device__ __forceinline__ void static __device__ __forceinline__ void
compute_attn_1rowblock_splitkv_sparse_mla_fp8(const SparseAttnDecodeParams &params, const DecodingSchedMeta& sched_meta, int batch_idx); compute_attn_1rowblock_splitkv_sparse_mla_fp8(const SparseAttnDecodeParams &params, const DecodingSchedMeta& sched_meta, int batch_idx);
...@@ -163,4 +114,7 @@ static void run(const SparseAttnDecodeParams &params); ...@@ -163,4 +114,7 @@ static void run(const SparseAttnDecodeParams &params);
}; };
} }
#include "../splitkv_mla.cuh"
// Explicit instantiation definition of run_flash_splitkv_mla_fp8_sparse_kernel
// for the template arguments <ModelType::MODEL1, 16>, so that this
// specialization is emitted by this translation unit.
// NOTE(review): the second argument is presumably the query-head count
// (the "h16" variant) and the template definition presumably comes from the
// header included just above — confirm.
namespace sm90::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(const SparseAttnDecodeParams &params);
}
#include "../splitkv_mla.cuh"
// Explicit instantiation definition of run_flash_splitkv_mla_fp8_sparse_kernel
// for the template arguments <ModelType::V32, 16>, so that this
// specialization is emitted by this translation unit.
// NOTE(review): the second argument is presumably the query-head count
// (the "h16" variant) and the template definition presumably comes from the
// header included just above — confirm.
namespace sm90::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(const SparseAttnDecodeParams &params);
}
...@@ -58,8 +58,10 @@ ext_modules.append( ...@@ -58,8 +58,10 @@ ext_modules.append(
"csrc/sm90/decode/dense/instantiations/bf16.cu", "csrc/sm90/decode/dense/instantiations/bf16.cu",
# # sm90 sparse decode # # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu", "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu", "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
......
...@@ -27,7 +27,7 @@ def gen_testcase() -> List[RawTestParam]: ...@@ -27,7 +27,7 @@ def gen_testcase() -> List[RawTestParam]:
for have_extra_k in ([False, True] if d_qk == 512 else [False]): for have_extra_k in ([False, True] if d_qk == 512 else [False]):
for have_extra_topk_len in ([False, True] if have_extra_k else [False]): for have_extra_topk_len in ([False, True] if have_extra_k else [False]):
for have_topk_len in ([False, True] if d_qk == 512 else [False]): for have_topk_len in ([False, True] if d_qk == 512 else [False]):
for h_q in [64, 128]: for h_q in [16, 64, 128]:
cur_correctness_cases = [ cur_correctness_cases = [
RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk, RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
have_topk_length=have_topk_len, have_topk_length=have_topk_len,
...@@ -119,7 +119,7 @@ def gen_testcase() -> List[RawTestParam]: ...@@ -119,7 +119,7 @@ def gen_testcase() -> List[RawTestParam]:
] + [ ] + [
# Peak perf cases # Peak perf cases
RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk) RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)
for h_q in [64, 128] for h_q in [16, 64, 128]
for d_qk in [512, 576] for d_qk in [512, 576]
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment