Commit b894e2da authored by zhanghj2
Browse files

优化mtp场景 (Optimize the MTP scenario)

parent 7efb944d
......@@ -271,13 +271,12 @@ get_mla_decoding_metadata_dense_fp8(
// This should match the logic in the MLA kernel.
int block_size_m = 16;
static constexpr int block_size_n = 64;
if (h_q.has_value()) {
if (h_q.value() >= 64) {
if (num_heads_per_head_k > 32) {
block_size_m = 64;
} else if (h_q.value() > 16) {
} else if (num_heads_per_head_k > 16) {
block_size_m = 32;
}
}
static constexpr int fixed_overhead_num_blocks = 5;
CHECK_DEVICE(seqlens_k);
......
......@@ -4314,7 +4314,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const std::string& kv
if (kv_cache_dtype == "auto") {
// printf(" seqlen_q %d \n", params.seqlen_q);
if (params.ngroups >= 64) {
if (params.seqlen_q > 32) {
using Kernel_traits = Flash_fwd_kernel_traits_mla_tp1<576, 64, 64, 8, T, 512>;
run_flash_splitkv_fwd_mla_tp1<Kernel_traits, flash::SharedStorageMLATP1<Kernel_traits>, Fp8KVCacheDataType::kAuto>(params, stream);
} else {
......@@ -4325,7 +4325,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const std::string& kv
using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512>;
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLAFp8<Kernel_traits>, Fp8KVCacheDataType::kFp8E4M3>(params, stream);
} else if (kv_cache_dtype == "fp8_e5m2") {
if (params.ngroups >= 64) {
if (params.seqlen_q > 32) {
using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8_TP1<576, 64, 64, 8, T, 512>;
run_flash_splitkv_fwd_mla_tp1<Kernel_traits, flash::SharedStorageMLAFp8_TP1<Kernel_traits>, Fp8KVCacheDataType::kFp8E5M2>(params, stream);
} else {
......
......@@ -2781,10 +2781,10 @@ void run_mha_fwd_splitkv_mla_fp8(Flash_fwd_mla_params &params, cudaStream_t str
return;
}
if constexpr (std::is_same_v<T, cutlass::float_e4m3_t>) {
if (params.ngroups >= 64) {
if (params.seqlen_q > 32) {
using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8_TP1<576, 64, 64, 8, T, To, 512>;
run_flash_splitkv_fwd_mla_fp8_tp1<Kernel_traits, flash::SharedStorageMLAFloat8_TP1<Kernel_traits>>(params, stream);
} else if (params.ngroups > 16) {
} else if (params.seqlen_q > 16) {
using Kernel_traits = Flash_fwd_kernel_traits_mla_qkvfp8_TP4<576, 32, 64, 4, T, To, 512>;
run_flash_splitkv_fwd_mla_fp8_tp4<Kernel_traits, flash::SharedStorageMLAFloat8_TP4<Kernel_traits>>(params, stream);
} else {
......
......@@ -206,6 +206,12 @@ def main(torch_dtype, is_prof=False):
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1, 2, 3, 4]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
......@@ -227,18 +227,14 @@ def main(torch_dtype, is_prof=False):
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# for b in [1]:
# for s in [128]:
# for h_q in [128]:
# for s_q in [2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1,2,3,4]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# '''
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment