"vscode:/vscode.git/clone" did not exist on "34489f466e8f6ddf3b7318cd1556a5df8759c005"
Commit 0651671f authored by zhanghj2
Browse files

sparse decode: add support for head16

parent b94fdd0f
......@@ -57,6 +57,9 @@ inline int int64_stride_to_int(int64_t orig_stride) {
} else if (NUM_HEADS == 64) { \
static constexpr int CONSTEXPR_NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_HEADS <= 16) { \
static constexpr int CONSTEXPR_NAME = 16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \
} \
......
......@@ -10,6 +10,7 @@
// Feature set of sparse decoding kernels
enum class DecodeFeatures : int {
HEAD_16,
HEAD_64,
HEAD_128,
......@@ -41,6 +42,7 @@ public:
class Decode_Sm90_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_16,
DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
......@@ -56,6 +58,13 @@ class Decode_Sm90_Impl : public DecodeImplBase {
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
// Small-head path. Every h_q <= 16 is routed to the single 16-head variant
// (see the NUM_HEADS <= 16 dispatch and the HEAD_16 feature selection), so
// the head-group divisor is always 1. The earlier form divided by
// (h_q / 16), which truncates to 0 for h_q < 16 and is a division by zero;
// for h_q == 16 the computed value is unchanged.
if (h_q <= 16) {
    return {
        std::max(arch.num_sms * 2 / s_q, 1),
        5,
        64
    };
}
return {
std::max(arch.num_sms / s_q / (h_q/64), 1),
5,
......@@ -218,7 +227,9 @@ sparse_attn_decode_interface(
}
std::vector<DecodeFeatures> features;
if (h_q == 64) {
if (h_q <= 16) {
features.push_back(DecodeFeatures::HEAD_16);
} else if (h_q == 64) {
features.push_back(DecodeFeatures::HEAD_64);
} else if (h_q == 128) {
features.push_back(DecodeFeatures::HEAD_128);
......
......@@ -100,58 +100,9 @@ struct SharedMemoryPlan {
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutRow>> smem_row_max;
};
// struct {
// cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
// // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
// };
// struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
// };
};
// array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// union {
// array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
// array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
// array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
// } u;
// CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
// bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
// float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
// transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};
// template<
// typename Shape_Q, typename TMA_Q
// >
// using TiledMMA_QK = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_QK_rQ = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// Declaration only; the definition is not visible in this view.
// Device-side, force-inlined routine that takes the decode parameters, the
// scheduling metadata and a batch index.
// NOTE(review): the name suggests it processes one row block of split-KV
// sparse MLA attention in FP8 for the given batch entry — confirm against the
// definition.
static __device__ __forceinline__ void
compute_attn_1rowblock_splitkv_sparse_mla_fp8(const SparseAttnDecodeParams &params, const DecodingSchedMeta& sched_meta, int batch_idx);
......@@ -163,4 +114,7 @@ static void run(const SparseAttnDecodeParams &params);
};
}
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of run_flash_splitkv_mla_fp8_sparse_kernel for
// ModelType::MODEL1 with second template argument 16 (the h16 variant).
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(
    const SparseAttnDecodeParams&);

}  // namespace sm90::decode::sparse_fp8
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of run_flash_splitkv_mla_fp8_sparse_kernel for
// ModelType::V32 with second template argument 16 (the h16 variant).
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(
    const SparseAttnDecodeParams&);

}  // namespace sm90::decode::sparse_fp8
......@@ -58,8 +58,10 @@ ext_modules.append(
"csrc/sm90/decode/dense/instantiations/bf16.cu",
# # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
......
......@@ -27,7 +27,7 @@ def gen_testcase() -> List[RawTestParam]:
for have_extra_k in ([False, True] if d_qk == 512 else [False]):
for have_extra_topk_len in ([False, True] if have_extra_k else [False]):
for have_topk_len in ([False, True] if d_qk == 512 else [False]):
for h_q in [64, 128]:
for h_q in [16, 64, 128]:
cur_correctness_cases = [
RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk,
have_topk_length=have_topk_len,
......@@ -119,7 +119,7 @@ def gen_testcase() -> List[RawTestParam]:
] + [
# Peak perf cases
RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk)
for h_q in [64, 128]
for h_q in [16, 64, 128]
for d_qk in [512, 576]
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment