Commit 929ccc23 authored by shenzhe
Browse files

Refine DSA MLS prefill and BF16 sparse decode dispatch

parent def4e3a9
...@@ -29,6 +29,10 @@ class FwdImplBase : public ImplBase< ...@@ -29,6 +29,10 @@ class FwdImplBase : public ImplBase<
> {}; > {};
class Fwd_Sm90_Impl : public FwdImplBase { class Fwd_Sm90_Impl : public FwdImplBase {
public:
explicit Fwd_Sm90_Impl(bool enable_dsa_mls_prefill)
: enable_dsa_mls_prefill_(enable_dsa_mls_prefill) {}
DECLARE_SUPPORTED_FEATURES( DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_16, FwdFeatures::HEAD_16,
FwdFeatures::HEAD_64, FwdFeatures::HEAD_64,
...@@ -42,8 +46,11 @@ class Fwd_Sm90_Impl : public FwdImplBase { ...@@ -42,8 +46,11 @@ class Fwd_Sm90_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
if ((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) || const bool disable_dsa_mls_prefill = std::getenv("FLASH_MLA_DISABLE_DSA_MLS_PREFILL") != nullptr;
gfx93::fwd::dsa_mls::should_run(params)) { if (enable_dsa_mls_prefill_ &&
!disable_dsa_mls_prefill &&
((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) ||
gfx93::fwd::dsa_mls::should_run(params))) {
gfx93::fwd::dsa_mls::run(params); gfx93::fwd::dsa_mls::run(params);
return; return;
} }
...@@ -54,6 +61,9 @@ protected: ...@@ -54,6 +61,9 @@ protected:
}); });
}); });
} }
private:
bool enable_dsa_mls_prefill_;
}; };
static std::vector<at::Tensor> sparse_attn_prefill_interface( static std::vector<at::Tensor> sparse_attn_prefill_interface(
...@@ -182,7 +192,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -182,7 +192,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
} }
if (arch.is_gfx93x()) { if (arch.is_gfx93x()) {
Fwd_Sm90_Impl fwd_impl; Fwd_Sm90_Impl fwd_impl(arch.is_gfx938());
fwd_impl.run(params, required_features); fwd_impl.run(params, required_features);
} else { } else {
TORCH_CHECK(false, "Unsupported architecture"); TORCH_CHECK(false, "Unsupported architecture");
......
...@@ -31,6 +31,11 @@ struct LocalArch { ...@@ -31,6 +31,11 @@ struct LocalArch {
const auto base = arch_name.substr(0, arch_name.find(':')); const auto base = arch_name.substr(0, arch_name.find(':'));
return base == "gfx936" || base == "gfx938"; return base == "gfx936" || base == "gfx938";
} }
bool is_gfx938() const {
    // Compare only the leading part of arch_name: everything from the first ':'
    // onward is dropped before the comparison. When no ':' is present, find()
    // yields npos and substr() keeps the entire string.
    const auto colon_pos = arch_name.find(':');
    return arch_name.substr(0, colon_pos) == "gfx938";
}
}; };
static int int64_stride_to_int(int64_t stride) { static int int64_stride_to_int(int64_t stride) {
...@@ -38,13 +43,25 @@ static int int64_stride_to_int(int64_t stride) { ...@@ -38,13 +43,25 @@ static int int64_stride_to_int(int64_t stride) {
return static_cast<int>(stride); return static_cast<int>(stride);
} }
static int default_num_splits(int topk, int extra_topk) { static int default_num_splits(int b, int s_q, int topk, int extra_topk) {
if (extra_topk > 0) { if (extra_topk > 0) {
return 2; return 2;
} }
if (topk == 1024) return 16;
if (topk == 512) return 8; int split = 1;
return 1; if (topk > 1024) {
split = 32;
} else if (topk == 1024) {
split = 16;
} else if (topk == 512) {
split = 8;
}
constexpr int64_t kMaxDecodeTasksBeforeReducingSplit = 2048;
while (split > 1 && static_cast<int64_t>(b) * s_q * split > kMaxDecodeTasksBeforeReducingSplit) {
split /= 2;
}
return split;
} }
static void check_optional_extra( static void check_optional_extra(
...@@ -74,7 +91,7 @@ run( ...@@ -74,7 +91,7 @@ run(
int d_v, int d_v,
float sm_scale) { float sm_scale) {
LocalArch arch; LocalArch arch;
TORCH_CHECK(arch.is_gfx93x(), "DSA BF16 sparse decode is only supported on gfx936/gfx938"); TORCH_CHECK(arch.is_gfx938(), "DSA BF16 sparse decode is only supported on gfx938");
KU_CHECK_NDIM(q, 4); KU_CHECK_NDIM(q, 4);
KU_CHECK_NDIM(kv, 4); KU_CHECK_NDIM(kv, 4);
...@@ -97,17 +114,13 @@ run( ...@@ -97,17 +114,13 @@ run(
TORCH_CHECK(b > 0 && s_q > 0 && h_q > 0, "Invalid q shape for DSA BF16 sparse decode"); TORCH_CHECK(b > 0 && s_q > 0 && h_q > 0, "Invalid q shape for DSA BF16 sparse decode");
TORCH_CHECK(h_kv == 1, "DSA BF16 sparse decode only supports h_kv == 1"); TORCH_CHECK(h_kv == 1, "DSA BF16 sparse decode only supports h_kv == 1");
TORCH_CHECK(h_q == 64 || h_q == 128, "DSA BF16 sparse decode only supports h_q == 64 or 128"); TORCH_CHECK(h_q == 64 || h_q == 128, "DSA BF16 sparse decode only supports h_q == 64 or 128");
TORCH_CHECK(d_qk == 512 || d_qk == 576, "DSA BF16 sparse decode only supports d_qk == 512 or 576"); TORCH_CHECK(d_qk == 512, "DSA BF16 sparse decode only supports d_qk == 512 for now");
TORCH_CHECK(d_v == 512, "DSA BF16 sparse decode only supports d_v == 512"); TORCH_CHECK(d_v == 512, "DSA BF16 sparse decode only supports d_v == 512");
TORCH_CHECK(topk > 0, "topk must be positive"); TORCH_CHECK(topk > 0, "topk must be positive");
if (has_extra) { if (has_extra) {
TORCH_CHECK(topk <= 256, "DSA BF16 sparse decode with extra_kv supports topk <= 256");
TORCH_CHECK(extra_topk <= 1024, "DSA BF16 sparse decode supports extra_topk <= 1024");
TORCH_CHECK(extra_kv->size(1) > 0, "extra page_block_size must be positive"); TORCH_CHECK(extra_kv->size(1) > 0, "extra page_block_size must be positive");
TORCH_CHECK(extra_kv->size(2) == h_kv, "extra_kv h_kv must match kv h_kv"); TORCH_CHECK(extra_kv->size(2) == h_kv, "extra_kv h_kv must match kv h_kv");
TORCH_CHECK(extra_kv->size(3) == d_qk, "extra_kv d_qk must match q d_qk"); TORCH_CHECK(extra_kv->size(3) == d_qk, "extra_kv d_qk must match q d_qk");
} else {
TORCH_CHECK(topk <= 1024, "DSA BF16 sparse decode supports topk <= 1024");
} }
check_optional_extra(extra_kv, extra_indices, extra_topk_length); check_optional_extra(extra_kv, extra_indices, extra_topk_length);
...@@ -167,7 +180,7 @@ run( ...@@ -167,7 +180,7 @@ run(
at::Tensor scores_sum = scores_memory.select(0, 1); at::Tensor scores_sum = scores_memory.select(0, 1);
if (!num_splits.has_value()) { if (!num_splits.has_value()) {
const int split = default_num_splits(topk, extra_topk); const int split = default_num_splits(b, s_q, topk, extra_topk);
num_splits = torch::empty({1}, opts.dtype(torch::kInt32)); num_splits = torch::empty({1}, opts.dtype(torch::kInt32));
num_splits->fill_(split); num_splits->fill_(split);
} }
...@@ -177,6 +190,14 @@ run( ...@@ -177,6 +190,14 @@ run(
TORCH_CHECK(num_splits->numel() == 1, "DSA BF16 sparse decode expects num_splits to be a scalar tensor"); TORCH_CHECK(num_splits->numel() == 1, "DSA BF16 sparse decode expects num_splits to be a scalar tensor");
const int requested_num_splits = num_splits->item<int>(); const int requested_num_splits = num_splits->item<int>();
TORCH_CHECK(requested_num_splits >= 1 && requested_num_splits <= 64, "DSA BF16 sparse decode requires 1 <= num_splits <= 64"); TORCH_CHECK(requested_num_splits >= 1 && requested_num_splits <= 64, "DSA BF16 sparse decode requires 1 <= num_splits <= 64");
if (requested_num_splits == 1) {
if (has_extra) {
TORCH_CHECK(topk <= 256, "DSA BF16 sparse decode with extra_kv and num_splits == 1 supports topk <= 256");
TORCH_CHECK(extra_topk <= 1024, "DSA BF16 sparse decode with extra_kv and num_splits == 1 supports extra_topk <= 1024");
} else {
TORCH_CHECK(topk <= 1024, "DSA BF16 sparse decode with num_splits == 1 supports topk <= 1024");
}
}
Flash_fwd_mla_params_dsa params; Flash_fwd_mla_params_dsa params;
std::memset(&params, 0, sizeof(params)); std::memset(&params, 0, sizeof(params));
...@@ -246,6 +267,11 @@ run( ...@@ -246,6 +267,11 @@ run(
params.seqlenq_ngroups_swapped = true; params.seqlenq_ngroups_swapped = true;
params.is_seqlens_k_cumulative = false; params.is_seqlens_k_cumulative = false;
params.splitkv_use_fp32_as_accum = false; params.splitkv_use_fp32_as_accum = false;
constexpr int64_t kBufferLoadPaddedTokenLimit = 32LL * 64 * 1024;
const int64_t padded_k_tokens = static_cast<int64_t>(kv.size(0)) * page_block_size;
const int64_t padded_extra_k_tokens = has_extra ? static_cast<int64_t>(extra_kv->size(0)) * extra_kv->size(1) : 0;
params.decode_use_c_load = padded_k_tokens > kBufferLoadPaddedTokenLimit ||
padded_extra_k_tokens > kBufferLoadPaddedTokenLimit;
params.num_splits = requested_num_splits; params.num_splits = requested_num_splits;
params.partition_size = topk + params.extra_topk; params.partition_size = topk + params.extra_topk;
if (params.num_splits > 1) { if (params.num_splits > 1) {
...@@ -263,11 +289,7 @@ run( ...@@ -263,11 +289,7 @@ run(
} }
hipStream_t stream = reinterpret_cast<hipStream_t>(at::cuda::getCurrentCUDAStream().stream()); hipStream_t stream = reinterpret_cast<hipStream_t>(at::cuda::getCurrentCUDAStream().stream());
if (d_qk == 512) {
gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(params, stream); gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(params, stream);
} else {
gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 576, 512>(params, stream);
}
return {out, lse, tile_scheduler_metadata, num_splits}; return {out, lse, tile_scheduler_metadata, num_splits};
} }
......
...@@ -72,7 +72,7 @@ void run_dsa_mla_splitkv_reduce(Params& params, hipStream_t stream) { ...@@ -72,7 +72,7 @@ void run_dsa_mla_splitkv_reduce(Params& params, hipStream_t stream) {
} }
template<typename T, int Headdim, int HeaddimV> template<typename T, int Headdim, int HeaddimV>
void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) { static inline void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) {
constexpr int kBlockM = 64; constexpr int kBlockM = 64;
constexpr int kBlockN = 64; constexpr int kBlockN = 64;
constexpr int WARP_M = 16; constexpr int WARP_M = 16;
...@@ -100,13 +100,15 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr ...@@ -100,13 +100,15 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] { BOOL_SWITCH(has_extra, Has_extra, [&] {
BOOL_SWITCH(params.decode_use_c_load, DecodeCLoad, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64< flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal, Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa> IsEvenMNConst, true, false, Is_MTP, 0, DecodeCLoad, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); <<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}); });
}); });
}); });
});
} else if (params.num_splits != 0) { } else if (params.num_splits != 0) {
dimGrid.y = params.num_splits; dimGrid.y = params.num_splits;
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
...@@ -124,10 +126,25 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr ...@@ -124,10 +126,25 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) { if (params.topk == 2048) {
constexpr bool CanUseFastTopk2048 = Headdim == 576 && HeaddimV == 512;
if constexpr (CanUseFastTopk2048) {
if (params.seqlen_k < params.topk) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, true, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, false, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64< flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal, Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa> IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); <<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
} else { } else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024< flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<
Kernel_traits, true, Is_dropout, false, Is_causal, Kernel_traits, true, Is_dropout, false, Is_causal,
......
...@@ -27,7 +27,9 @@ bool should_run(const SparseAttnFwdParams& params) { ...@@ -27,7 +27,9 @@ bool should_run(const SparseAttnFwdParams& params) {
return true; return true;
} }
if (params.d_qk == 576 && params.h_q == 64 && params.topk == 2048 && params.s_kv >= 32768) { if (params.d_qk == 576 && params.topk == 2048 &&
((params.h_q == 64 && params.s_kv >= 24576) ||
(params.h_q == 128 && params.s_kv >= 8192))) {
return true; return true;
} }
......
...@@ -429,6 +429,7 @@ struct Flash_fwd_mla_params_dsa { ...@@ -429,6 +429,7 @@ struct Flash_fwd_mla_params_dsa {
bool seqlenq_ngroups_swapped; bool seqlenq_ngroups_swapped;
bool is_seqlens_k_cumulative; bool is_seqlens_k_cumulative;
bool splitkv_use_fp32_as_accum; bool splitkv_use_fp32_as_accum;
bool decode_use_c_load;
// not used params // not used params
float *__restrict__ scales_q_ptr; float *__restrict__ scales_q_ptr;
......
...@@ -562,7 +562,7 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params ...@@ -562,7 +562,7 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] { BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa> flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, true, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); <<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}); });
}); });
......
...@@ -55,7 +55,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -55,7 +55,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
} }
// max-rescale coefficient x sum-rescale coefficient // max-rescale coefficient x sum-rescale coefficient
lds[tx] = s_sum_load_ori * s_max_ratio / s_sum_tmp; lds[tx] = s_sum_tmp > 0.f ? s_sum_load_ori * s_max_ratio / s_sum_tmp : 0.f;
// finally, do rescale for each split and reduce the sum of them // finally, do rescale for each split and reduce the sum of them
// each block(1waves) process (num_splits x head_dim) elements in total // each block(1waves) process (num_splits x head_dim) elements in total
...@@ -76,7 +76,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -76,7 +76,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
for (int i = 0; i < num_splits; ++i) { for (int i = 0; i < num_splits; ++i) {
// read ultimate scale value for current split // read ultimate scale value for current split
float s_scale = lds[i]; float s_scale = lds[i];
bool within_splits = (i < true_num_splits); bool within_splits = (i < true_num_splits) && (s_scale > 0.f);
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
// read ultimate scale value for current split // read ultimate scale value for current split
...@@ -197,7 +197,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -197,7 +197,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
s_sum_tmp = lds[LDS_ACCUM]; s_sum_tmp = lds[LDS_ACCUM];
// max-rescale coefficient x sum-rescale coefficient // max-rescale coefficient x sum-rescale coefficient
lds[tx] = s_sum_load_ori * s_max_ratio / s_sum_tmp; lds[tx] = s_sum_tmp > 0.f ? s_sum_load_ori * s_max_ratio / s_sum_tmp : 0.f;
// finally, do rescale for each split and reduce the sum of them // finally, do rescale for each split and reduce the sum of them
// each block(multiple waves) process (num_splits x head_dim) elements in total // each block(multiple waves) process (num_splits x head_dim) elements in total
...@@ -224,7 +224,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -224,7 +224,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
for (int i = 0; i < split_count_this_wave; ++i) { for (int i = 0; i < split_count_this_wave; ++i) {
// read ultimate scale value for current split // read ultimate scale value for current split
float s_scale = lds[begin + i]; float s_scale = lds[begin + i];
bool within_splits = (begin + i) < true_num_splits; bool within_splits = ((begin + i) < true_num_splits) && (s_scale > 0.f);
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
...@@ -402,13 +402,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -402,13 +402,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step)); lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
} }
bool has_valid_lse = (lse_max_local != -INFINITY);
// reduce sum lse // reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local); float lse_local_logsum = has_valid_lse ? __expf(lse_local - lse_max_local) : 0.f;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step); lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
} }
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local; lse_local_logsum = has_valid_lse ? __logf(lse_local_logsum) + lse_max_local : -INFINITY;
// store softmax_lse // store softmax_lse
if (tx == 0) { if (tx == 0) {
...@@ -416,14 +417,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -416,14 +417,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
} }
// store rescale coefficient into lds // store rescale coefficient into lds
lds[tx] = __expf(lse_local - lse_local_logsum); lds[tx] = has_valid_lse ? __expf(lse_local - lse_local_logsum) : 0.f;
// num_splits may not be 64, and thus need boundary judgement // num_splits may not be 64, and thus need boundary judgement
#pragma unroll #pragma unroll
for (int i = 0; i < num_splits; ++i) { for (int i = 0; i < num_splits; ++i) {
// read ultimate scale value for current split // read ultimate scale value for current split
float s_scale = lds[i]; float s_scale = lds[i];
bool within_splits = (i < true_num_splits); bool within_splits = (i < true_num_splits) && (s_scale > 0.f);
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
...@@ -555,13 +556,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -555,13 +556,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step)); lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
} }
bool has_valid_lse = (lse_max_local != -INFINITY);
// reduce sum lse // reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local); float lse_local_logsum = has_valid_lse ? __expf(lse_local - lse_max_local) : 0.f;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step); lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
} }
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local; lse_local_logsum = has_valid_lse ? __logf(lse_local_logsum) + lse_max_local : -INFINITY;
float attn_sink_o_scale = 1.0f; float attn_sink_o_scale = 1.0f;
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
...@@ -578,7 +580,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -578,7 +580,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
} }
// store rescale coefficient into lds // store rescale coefficient into lds
lds[tx] = __expf(lse_local - lse_local_logsum) * attn_sink_o_scale; lds[tx] = has_valid_lse ? __expf(lse_local - lse_local_logsum) * attn_sink_o_scale : 0.f;
// num_splits may not be 64, and thus need boundary judgement // num_splits may not be 64, and thus need boundary judgement
#pragma unroll #pragma unroll
...@@ -586,6 +588,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -586,6 +588,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// read ultimate scale value for current split // read ultimate scale value for current split
bool within_splits = ((i + wave_id) < true_num_splits); bool within_splits = ((i + wave_id) < true_num_splits);
float s_scale = num_splits >= WARP_NUM ? lds[i + wave_id]: (within_splits ? lds[i + wave_id]: 0.f); float s_scale = num_splits >= WARP_NUM ? lds[i + wave_id]: (within_splits ? lds[i + wave_id]: 0.f);
within_splits = within_splits && (s_scale > 0.f);
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
......
#include "fwd.h" #include "fwd.h"
#include <cstdlib>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <hip/hip_runtime.h>
#include "dsa_mls/fwd.h" #include "dsa_mls/fwd.h"
#include "phase1.h" #include "phase1.h"
namespace gfx93 { namespace gfx93 {
namespace {
bool is_current_device_gfx938() {
int device = 0;
hipDeviceProp_t prop{};
if (hipGetDevice(&device) != hipSuccess || hipGetDeviceProperties(&prop, device) != hipSuccess) {
return false;
}
const std::string arch_name = prop.gcnArchName;
return arch_name.substr(0, arch_name.find(':')) == "gfx938";
}
} // namespace
void run_fwd_kernel(const SparseAttnFwdParams& params) { void run_fwd_kernel(const SparseAttnFwdParams& params) {
if (gfx93::fwd::dsa_mls::should_run(params)) { const bool disable_dsa_mls_prefill = std::getenv("FLASH_MLA_DISABLE_DSA_MLS_PREFILL") != nullptr;
const bool enable_dsa_mls_prefill = is_current_device_gfx938();
if (enable_dsa_mls_prefill && !disable_dsa_mls_prefill && gfx93::fwd::dsa_mls::should_run(params)) {
gfx93::fwd::dsa_mls::run(params); gfx93::fwd::dsa_mls::run(params);
return; return;
} }
......
...@@ -57,7 +57,7 @@ def flash_mla_with_kvcache( ...@@ -57,7 +57,7 @@ def flash_mla_with_kvcache(
cache_seqlens: Optional[torch.Tensor], cache_seqlens: Optional[torch.Tensor],
head_dim_v: int, head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta, tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None, num_splits: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
is_fp8_kvcache: bool = False, is_fp8_kvcache: bool = False,
...@@ -78,7 +78,7 @@ def flash_mla_with_kvcache( ...@@ -78,7 +78,7 @@ def flash_mla_with_kvcache(
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512 head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface). num_splits: optional override for BF16 sparse decode. Other paths keep using sched_meta.
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention causal: bool. Whether to apply causal attention mask. Only valid for dense attention
is_fp8_kvcache: bool. is_fp8_kvcache: bool.
...@@ -104,7 +104,6 @@ def flash_mla_with_kvcache( ...@@ -104,7 +104,6 @@ def flash_mla_with_kvcache(
sched_meta = tile_scheduler_metadata sched_meta = tile_scheduler_metadata
indices_in_kvcache = indices indices_in_kvcache = indices
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
...@@ -155,14 +154,18 @@ def flash_mla_with_kvcache( ...@@ -155,14 +154,18 @@ def flash_mla_with_kvcache(
assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False" assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
if extra_k_cache is not None: if extra_k_cache is not None:
assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False" assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
else:
assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
decode_num_splits = num_splits if num_splits is not None else sched_meta.num_splits
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
q, k_cache, indices_in_kvcache, topk_length, attn_sink, q, k_cache, indices_in_kvcache, topk_length, attn_sink,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits, sched_meta.tile_scheduler_metadata, decode_num_splits,
extra_k_cache, extra_indices_in_kvcache, extra_topk_length, extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
head_dim_v, softmax_scale head_dim_v, softmax_scale
) )
else: else:
# Dense attention # Dense attention
assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
......
# Build/version metadata constants; every value is a plain string.
# NOTE(review): presumably written by the build/packaging step rather than edited by hand — confirm.
version = '1.2.0'
git_hash = 'ac8223a'
git_branch = 'master-aicc'
abi = 'abi1'
dtk = '2604'
torch_version = '2.5'
# Composite identifier: its "1.2.0" prefix and "dtk2604" suffix match `version` and `dtk` above.
hcu_version = '1.2.0+das.opt1.dtk2604'
...@@ -54,11 +54,9 @@ def is_bf16_decode_supported_param(t: TestParam) -> bool: ...@@ -54,11 +54,9 @@ def is_bf16_decode_supported_param(t: TestParam) -> bool:
return False return False
if t.h_q not in [64, 128]: if t.h_q not in [64, 128]:
return False return False
if t.d_qk not in [512, 576]: if t.d_qk != 512:
return False return False
if t.decode.extra_topk is None: return True
return t.topk <= 1024
return t.topk <= 256 and t.decode.extra_topk <= 1024
@dataclasses.dataclass @dataclasses.dataclass
class RawTestParamForDecode: class RawTestParamForDecode:
......
...@@ -110,6 +110,11 @@ def gen_testcase() -> List[RawTestParam]: ...@@ -110,6 +110,11 @@ def gen_testcase() -> List[RawTestParam]:
(RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
# MODEL1 CONFIG4 # MODEL1 CONFIG4
(RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
# DSA BF16 large topk / extra_topk coverage
(RawTestParam(0, 64, 2, 1, 32768, True, topk=4096, d_qk=512, block_size=256, check_correctness=True, num_runs=0), [1, 2]),
(RawTestParam(0, 128, 2, 1, 32768, True, topk=8192, d_qk=576, block_size=256, check_correctness=True, num_runs=0), [1, 2]),
(RawTestParam(0, 64, 2, 1, 32768, True, topk=128, d_qk=512, extra_s_k=32768, extra_topk=4096, block_size=256, extra_block_size=256, check_correctness=True, num_runs=0), [1, 2]),
(RawTestParam(0, 128, 2, 1, 32768, True, topk=128, d_qk=512, extra_s_k=32768, extra_topk=8192, block_size=256, extra_block_size=256, have_extra_topk_length=True, check_correctness=True, num_runs=0), [1, 2]),
] ]
performance_cases = [ performance_cases = [
# Production cases # Production cases
...@@ -173,10 +178,13 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -173,10 +178,13 @@ def test_flash_mla(p: TestParam) -> Result:
result = kk.bench_kineto(run_decode, p.num_runs) result = kk.bench_kineto(run_decode, p.num_runs)
if lib.is_decode_bf16_kvcache(): if lib.is_decode_bf16_kvcache():
splitkv_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv" main_kernel_names = [
"flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv",
"flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<",
]
combine_kernel_name = "flash_mla_splitkv_reduce_kernel" combine_kernel_name = "flash_mla_splitkv_reduce_kernel"
else: else:
splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel" main_kernel_names = ["flash_fwd_splitkv_mla_fp8_sparse_kernel"]
combine_kernel_name = "flash_fwd_mla_combine_kernel" combine_kernel_name = "flash_fwd_mla_combine_kernel"
# Get individual kernel time usages # Get individual kernel time usages
...@@ -188,21 +196,24 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -188,21 +196,24 @@ def test_flash_mla(p: TestParam) -> Result:
kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6 kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6
else: else:
kernel_time_usages_us[kernel_name] = None kernel_time_usages_us[kernel_name] = None
pick_kernel_time_usage(splitkv_kernel_name) for main_kernel_name in main_kernel_names:
pick_kernel_time_usage(main_kernel_name)
pick_kernel_time_usage(combine_kernel_name) pick_kernel_time_usage(combine_kernel_name)
# Get E2E time usages # Get E2E time usages
def have_kernel(name: str): def have_kernel(name: str):
return kernel_time_usages_us[name] is not None return kernel_time_usages_us[name] is not None
active_main_kernel_name = next((name for name in main_kernel_names if have_kernel(name)), None)
if kk.is_using_profiling_tools(): if kk.is_using_profiling_tools():
e2e_time_usage_us = 1e6 e2e_time_usage_us = 1e6
else: else:
assert have_kernel(splitkv_kernel_name) assert active_main_kernel_name is not None
if have_kernel(combine_kernel_name): if have_kernel(combine_kernel_name):
e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6 e2e_time_usage_us = result.get_e2e_time(active_main_kernel_name, combine_kernel_name) * 1e6
else: else:
e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name] e2e_time_usage_us = kernel_time_usages_us[active_main_kernel_name]
assert e2e_time_usage_us is not None assert e2e_time_usage_us is not None
flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t) flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)
...@@ -216,12 +227,14 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -216,12 +227,14 @@ def test_flash_mla(p: TestParam) -> Result:
print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us') print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')
print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}') print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}')
print(f'Time (per): {e2e_time_usage_us:.1f} us') print(f'Time (per): {e2e_time_usage_us:.1f} us')
print_kernel_time_usage(splitkv_kernel_name, "Splitkv") for main_kernel_name in main_kernel_names:
print_kernel_time_usage(main_kernel_name, "Decode")
print_kernel_time_usage(combine_kernel_name, "Combine") print_kernel_time_usage(combine_kernel_name, "Combine")
print(f'TFlops: {achieved_tflops:.1f}') print(f'TFlops: {achieved_tflops:.1f}')
print(f'GB/s: {achieved_gBps:.0f}') print(f'GB/s: {achieved_gBps:.0f}')
performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps) main_kernel_time_usage_us = kernel_time_usages_us[active_main_kernel_name] if active_main_kernel_name is not None else 0.0
performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, main_kernel_time_usage_us or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)
is_correct = True is_correct = True
if p.check_correctness: if p.check_correctness:
...@@ -229,10 +242,11 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -229,10 +242,11 @@ def test_flash_mla(p: TestParam) -> Result:
with torch.profiler.record_function("reference_flash_mla"): with torch.profiler.record_function("reference_flash_mla"):
out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t) out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
if lib.is_decode_bf16_kvcache(): if lib.is_decode_bf16_kvcache():
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=6e-6)
is_correct &= is_out_correct is_correct &= is_out_correct
else: else:
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
is_correct &= is_out_correct and is_lse_correct is_correct &= is_out_correct and is_lse_correct
...@@ -261,10 +275,10 @@ def main(): ...@@ -261,10 +275,10 @@ def main():
bf16_testcases = [] bf16_testcases = []
seen_bf16_cases = set() seen_bf16_cases = set()
for t in testcases: for t in testcases:
if t.d_qk == 576:
t = dataclasses.replace(t, d_qk=512)
if not lib.is_bf16_decode_supported_param(t): if not lib.is_bf16_decode_supported_param(t):
continue continue
if t.num_runs > 0 and t.decode.b > 16:
t = dataclasses.replace(t, decode=dataclasses.replace(t.decode, b=16))
key = dataclasses.asdict(t) key = dataclasses.asdict(t)
key["decode"] = tuple(key["decode"].items()) if key["decode"] is not None else None key["decode"] = tuple(key["decode"].items()) if key["decode"] is not None else None
key = tuple(key.items()) key = tuple(key.items())
......
...@@ -11,6 +11,8 @@ import ref ...@@ -11,6 +11,8 @@ import ref
_counter = kk.Counter() _counter = kk.Counter()
def is_dsa_mls_prefill_case(p: TestParam) -> bool: def is_dsa_mls_prefill_case(p: TestParam) -> bool:
if get_gcn_arch_name() != "gfx938":
return False
if p.d_v != 512: if p.d_v != 512:
return False return False
if p.d_qk not in [512, 576]: if p.d_qk not in [512, 576]:
...@@ -26,7 +28,7 @@ def is_dsa_mls_prefill_case(p: TestParam) -> bool: ...@@ -26,7 +28,7 @@ def is_dsa_mls_prefill_case(p: TestParam) -> bool:
if p.d_qk == 512 and ((p.h_q == 64 and p.topk == 512) or (p.h_q == 128 and p.topk == 1024)): if p.d_qk == 512 and ((p.h_q == 64 and p.topk == 512) or (p.h_q == 128 and p.topk == 1024)):
return True return True
if p.d_qk == 576 and p.h_q == 64 and p.topk == 2048 and p.s_kv >= 32768: if p.d_qk == 576 and p.topk == 2048 and ((p.h_q == 64 and p.s_kv >= 24576) or (p.h_q == 128 and p.s_kv >= 8192)):
return True return True
return False return False
...@@ -53,9 +55,16 @@ def run_test(p: TestParam) -> bool: ...@@ -53,9 +55,16 @@ def run_test(p: TestParam) -> bool:
flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t) flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t)
bench_result = kk.bench_kineto(run_prefill, num_tests=p.num_runs) bench_result = kk.bench_kineto(run_prefill, num_tests=p.num_runs)
kernel_names = bench_result.get_kernel_names() kernel_names = bench_result.get_kernel_names()
prefill_kernel_name = "sparse_attn_fwd" prefill_kernel_name_candidates = [
if not any(prefill_kernel_name in name for name in kernel_names): "sparse_attn_fwd",
prefill_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64" "flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64",
"flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64",
]
prefill_kernel_name = next(
(candidate for candidate in prefill_kernel_name_candidates
if any(candidate in name for name in kernel_names)),
prefill_kernel_name_candidates[0],
)
prefill_ans_time = bench_result.get_kernel_time(prefill_kernel_name) prefill_ans_time = bench_result.get_kernel_time(prefill_kernel_name)
prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12 prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12
prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12 prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12
...@@ -69,6 +78,8 @@ def run_test(p: TestParam) -> bool: ...@@ -69,6 +78,8 @@ def run_test(p: TestParam) -> bool:
is_correct = True is_correct = True
is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6) is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)
# DSA MLS prefill is selected for throughput and currently only treats out as the validated contract.
# max_logits/lse can differ on boundary cases, so keep those checks on the Sugon path only.
if not is_dsa_mls_prefill_case(p): if not is_dsa_mls_prefill_case(p):
is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
...@@ -177,13 +188,13 @@ if __name__ == '__main__': ...@@ -177,13 +188,13 @@ if __name__ == '__main__':
performance_case_templates = [ performance_case_templates = [
# V3.2 # V3.2
(576, 128, 2048, [8192, 32768, 65536, 98304, 131072]), (576, 128, 2048, [8192, 16384, 65536, 98304, 131072]),
(576, 64, 2048, [8192, 32768, 65536, 98304, 131072]), (576, 64, 2048, [8192, 16384, 65536, 98304, 131072]),
# MODEL1 CONFIG1 # MODEL1 CONFIG1
(512, 64, 512, [8192, 32768, 49152, 65536]), # (512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2 # MODEL1 CONFIG2
(512, 128, 1024, [8192, 32768, 49152, 65536]), # (512, 128, 1024, [8192, 32768, 49152, 65536]),
(512, 16, 1024, [8192, 32768, 49152, 65536]), # (512, 16, 1024, [8192, 32768, 49152, 65536]),
] ]
performance_cases = [ performance_cases = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment