Commit 929ccc23 authored by shenzhe's avatar shenzhe
Browse files

Refine DSA MLS prefill and BF16 sparse decode dispatch

parent def4e3a9
...@@ -29,6 +29,10 @@ class FwdImplBase : public ImplBase< ...@@ -29,6 +29,10 @@ class FwdImplBase : public ImplBase<
> {}; > {};
class Fwd_Sm90_Impl : public FwdImplBase { class Fwd_Sm90_Impl : public FwdImplBase {
public:
explicit Fwd_Sm90_Impl(bool enable_dsa_mls_prefill)
: enable_dsa_mls_prefill_(enable_dsa_mls_prefill) {}
DECLARE_SUPPORTED_FEATURES( DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_16, FwdFeatures::HEAD_16,
FwdFeatures::HEAD_64, FwdFeatures::HEAD_64,
...@@ -42,8 +46,11 @@ class Fwd_Sm90_Impl : public FwdImplBase { ...@@ -42,8 +46,11 @@ class Fwd_Sm90_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
if ((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) || const bool disable_dsa_mls_prefill = std::getenv("FLASH_MLA_DISABLE_DSA_MLS_PREFILL") != nullptr;
gfx93::fwd::dsa_mls::should_run(params)) { if (enable_dsa_mls_prefill_ &&
!disable_dsa_mls_prefill &&
((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) ||
gfx93::fwd::dsa_mls::should_run(params))) {
gfx93::fwd::dsa_mls::run(params); gfx93::fwd::dsa_mls::run(params);
return; return;
} }
...@@ -54,6 +61,9 @@ protected: ...@@ -54,6 +61,9 @@ protected:
}); });
}); });
} }
private:
bool enable_dsa_mls_prefill_;
}; };
static std::vector<at::Tensor> sparse_attn_prefill_interface( static std::vector<at::Tensor> sparse_attn_prefill_interface(
...@@ -182,7 +192,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -182,7 +192,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
} }
if (arch.is_gfx93x()) { if (arch.is_gfx93x()) {
Fwd_Sm90_Impl fwd_impl; Fwd_Sm90_Impl fwd_impl(arch.is_gfx938());
fwd_impl.run(params, required_features); fwd_impl.run(params, required_features);
} else { } else {
TORCH_CHECK(false, "Unsupported architecture"); TORCH_CHECK(false, "Unsupported architecture");
......
...@@ -31,6 +31,11 @@ struct LocalArch { ...@@ -31,6 +31,11 @@ struct LocalArch {
const auto base = arch_name.substr(0, arch_name.find(':')); const auto base = arch_name.substr(0, arch_name.find(':'));
return base == "gfx936" || base == "gfx938"; return base == "gfx936" || base == "gfx938";
} }
bool is_gfx938() const {
const auto base = arch_name.substr(0, arch_name.find(':'));
return base == "gfx938";
}
}; };
static int int64_stride_to_int(int64_t stride) { static int int64_stride_to_int(int64_t stride) {
...@@ -38,13 +43,25 @@ static int int64_stride_to_int(int64_t stride) { ...@@ -38,13 +43,25 @@ static int int64_stride_to_int(int64_t stride) {
return static_cast<int>(stride); return static_cast<int>(stride);
} }
static int default_num_splits(int topk, int extra_topk) { static int default_num_splits(int b, int s_q, int topk, int extra_topk) {
if (extra_topk > 0) { if (extra_topk > 0) {
return 2; return 2;
} }
if (topk == 1024) return 16;
if (topk == 512) return 8; int split = 1;
return 1; if (topk > 1024) {
split = 32;
} else if (topk == 1024) {
split = 16;
} else if (topk == 512) {
split = 8;
}
constexpr int64_t kMaxDecodeTasksBeforeReducingSplit = 2048;
while (split > 1 && static_cast<int64_t>(b) * s_q * split > kMaxDecodeTasksBeforeReducingSplit) {
split /= 2;
}
return split;
} }
static void check_optional_extra( static void check_optional_extra(
...@@ -74,7 +91,7 @@ run( ...@@ -74,7 +91,7 @@ run(
int d_v, int d_v,
float sm_scale) { float sm_scale) {
LocalArch arch; LocalArch arch;
TORCH_CHECK(arch.is_gfx93x(), "DSA BF16 sparse decode is only supported on gfx936/gfx938"); TORCH_CHECK(arch.is_gfx938(), "DSA BF16 sparse decode is only supported on gfx938");
KU_CHECK_NDIM(q, 4); KU_CHECK_NDIM(q, 4);
KU_CHECK_NDIM(kv, 4); KU_CHECK_NDIM(kv, 4);
...@@ -97,17 +114,13 @@ run( ...@@ -97,17 +114,13 @@ run(
TORCH_CHECK(b > 0 && s_q > 0 && h_q > 0, "Invalid q shape for DSA BF16 sparse decode"); TORCH_CHECK(b > 0 && s_q > 0 && h_q > 0, "Invalid q shape for DSA BF16 sparse decode");
TORCH_CHECK(h_kv == 1, "DSA BF16 sparse decode only supports h_kv == 1"); TORCH_CHECK(h_kv == 1, "DSA BF16 sparse decode only supports h_kv == 1");
TORCH_CHECK(h_q == 64 || h_q == 128, "DSA BF16 sparse decode only supports h_q == 64 or 128"); TORCH_CHECK(h_q == 64 || h_q == 128, "DSA BF16 sparse decode only supports h_q == 64 or 128");
TORCH_CHECK(d_qk == 512 || d_qk == 576, "DSA BF16 sparse decode only supports d_qk == 512 or 576"); TORCH_CHECK(d_qk == 512, "DSA BF16 sparse decode only supports d_qk == 512 for now");
TORCH_CHECK(d_v == 512, "DSA BF16 sparse decode only supports d_v == 512"); TORCH_CHECK(d_v == 512, "DSA BF16 sparse decode only supports d_v == 512");
TORCH_CHECK(topk > 0, "topk must be positive"); TORCH_CHECK(topk > 0, "topk must be positive");
if (has_extra) { if (has_extra) {
TORCH_CHECK(topk <= 256, "DSA BF16 sparse decode with extra_kv supports topk <= 256");
TORCH_CHECK(extra_topk <= 1024, "DSA BF16 sparse decode supports extra_topk <= 1024");
TORCH_CHECK(extra_kv->size(1) > 0, "extra page_block_size must be positive"); TORCH_CHECK(extra_kv->size(1) > 0, "extra page_block_size must be positive");
TORCH_CHECK(extra_kv->size(2) == h_kv, "extra_kv h_kv must match kv h_kv"); TORCH_CHECK(extra_kv->size(2) == h_kv, "extra_kv h_kv must match kv h_kv");
TORCH_CHECK(extra_kv->size(3) == d_qk, "extra_kv d_qk must match q d_qk"); TORCH_CHECK(extra_kv->size(3) == d_qk, "extra_kv d_qk must match q d_qk");
} else {
TORCH_CHECK(topk <= 1024, "DSA BF16 sparse decode supports topk <= 1024");
} }
check_optional_extra(extra_kv, extra_indices, extra_topk_length); check_optional_extra(extra_kv, extra_indices, extra_topk_length);
...@@ -167,7 +180,7 @@ run( ...@@ -167,7 +180,7 @@ run(
at::Tensor scores_sum = scores_memory.select(0, 1); at::Tensor scores_sum = scores_memory.select(0, 1);
if (!num_splits.has_value()) { if (!num_splits.has_value()) {
const int split = default_num_splits(topk, extra_topk); const int split = default_num_splits(b, s_q, topk, extra_topk);
num_splits = torch::empty({1}, opts.dtype(torch::kInt32)); num_splits = torch::empty({1}, opts.dtype(torch::kInt32));
num_splits->fill_(split); num_splits->fill_(split);
} }
...@@ -177,6 +190,14 @@ run( ...@@ -177,6 +190,14 @@ run(
TORCH_CHECK(num_splits->numel() == 1, "DSA BF16 sparse decode expects num_splits to be a scalar tensor"); TORCH_CHECK(num_splits->numel() == 1, "DSA BF16 sparse decode expects num_splits to be a scalar tensor");
const int requested_num_splits = num_splits->item<int>(); const int requested_num_splits = num_splits->item<int>();
TORCH_CHECK(requested_num_splits >= 1 && requested_num_splits <= 64, "DSA BF16 sparse decode requires 1 <= num_splits <= 64"); TORCH_CHECK(requested_num_splits >= 1 && requested_num_splits <= 64, "DSA BF16 sparse decode requires 1 <= num_splits <= 64");
if (requested_num_splits == 1) {
if (has_extra) {
TORCH_CHECK(topk <= 256, "DSA BF16 sparse decode with extra_kv and num_splits == 1 supports topk <= 256");
TORCH_CHECK(extra_topk <= 1024, "DSA BF16 sparse decode with extra_kv and num_splits == 1 supports extra_topk <= 1024");
} else {
TORCH_CHECK(topk <= 1024, "DSA BF16 sparse decode with num_splits == 1 supports topk <= 1024");
}
}
Flash_fwd_mla_params_dsa params; Flash_fwd_mla_params_dsa params;
std::memset(&params, 0, sizeof(params)); std::memset(&params, 0, sizeof(params));
...@@ -246,6 +267,11 @@ run( ...@@ -246,6 +267,11 @@ run(
params.seqlenq_ngroups_swapped = true; params.seqlenq_ngroups_swapped = true;
params.is_seqlens_k_cumulative = false; params.is_seqlens_k_cumulative = false;
params.splitkv_use_fp32_as_accum = false; params.splitkv_use_fp32_as_accum = false;
constexpr int64_t kBufferLoadPaddedTokenLimit = 32LL * 64 * 1024;
const int64_t padded_k_tokens = static_cast<int64_t>(kv.size(0)) * page_block_size;
const int64_t padded_extra_k_tokens = has_extra ? static_cast<int64_t>(extra_kv->size(0)) * extra_kv->size(1) : 0;
params.decode_use_c_load = padded_k_tokens > kBufferLoadPaddedTokenLimit ||
padded_extra_k_tokens > kBufferLoadPaddedTokenLimit;
params.num_splits = requested_num_splits; params.num_splits = requested_num_splits;
params.partition_size = topk + params.extra_topk; params.partition_size = topk + params.extra_topk;
if (params.num_splits > 1) { if (params.num_splits > 1) {
...@@ -263,11 +289,7 @@ run( ...@@ -263,11 +289,7 @@ run(
} }
hipStream_t stream = reinterpret_cast<hipStream_t>(at::cuda::getCurrentCUDAStream().stream()); hipStream_t stream = reinterpret_cast<hipStream_t>(at::cuda::getCurrentCUDAStream().stream());
if (d_qk == 512) { gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(params, stream);
gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 512, 512>(params, stream);
} else {
gfx93::fwd::dsa_mls::run_dsa_prefill_nopage_64_dispatch<BFloat16, 576, 512>(params, stream);
}
return {out, lse, tile_scheduler_metadata, num_splits}; return {out, lse, tile_scheduler_metadata, num_splits};
} }
......
...@@ -72,7 +72,7 @@ void run_dsa_mla_splitkv_reduce(Params& params, hipStream_t stream) { ...@@ -72,7 +72,7 @@ void run_dsa_mla_splitkv_reduce(Params& params, hipStream_t stream) {
} }
template<typename T, int Headdim, int HeaddimV> template<typename T, int Headdim, int HeaddimV>
void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) { static inline void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStream_t stream) {
constexpr int kBlockM = 64; constexpr int kBlockM = 64;
constexpr int kBlockN = 64; constexpr int kBlockN = 64;
constexpr int WARP_M = 16; constexpr int WARP_M = 16;
...@@ -100,10 +100,12 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr ...@@ -100,10 +100,12 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] { BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64< BOOL_SWITCH(params.decode_use_c_load, DecodeCLoad, [&] {
Kernel_traits, true, Is_dropout, false, Is_causal, flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<
IsEvenMNConst, true, false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa> Kernel_traits, true, Is_dropout, false, Is_causal,
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); IsEvenMNConst, true, false, Is_MTP, 0, DecodeCLoad, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
});
}); });
}); });
}); });
...@@ -124,10 +126,25 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr ...@@ -124,10 +126,25 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.topk == 2048) { if (params.topk == 2048) {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64< constexpr bool CanUseFastTopk2048 = Headdim == 576 && HeaddimV == 512;
Kernel_traits, true, Is_dropout, false, Is_causal, if constexpr (CanUseFastTopk2048) {
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa> if (params.seqlen_k < params.topk) {
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, true, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, false, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
} else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64<
Kernel_traits, true, Is_dropout, false, Is_causal,
IsEvenMNConst, true, false, Is_MTP, 0, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}
} else { } else {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024< flash::flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024<
Kernel_traits, true, Is_dropout, false, Is_causal, Kernel_traits, true, Is_dropout, false, Is_causal,
......
...@@ -27,7 +27,9 @@ bool should_run(const SparseAttnFwdParams& params) { ...@@ -27,7 +27,9 @@ bool should_run(const SparseAttnFwdParams& params) {
return true; return true;
} }
if (params.d_qk == 576 && params.h_q == 64 && params.topk == 2048 && params.s_kv >= 32768) { if (params.d_qk == 576 && params.topk == 2048 &&
((params.h_q == 64 && params.s_kv >= 24576) ||
(params.h_q == 128 && params.s_kv >= 8192))) {
return true; return true;
} }
......
...@@ -429,6 +429,7 @@ struct Flash_fwd_mla_params_dsa { ...@@ -429,6 +429,7 @@ struct Flash_fwd_mla_params_dsa {
bool seqlenq_ngroups_swapped; bool seqlenq_ngroups_swapped;
bool is_seqlens_k_cumulative; bool is_seqlens_k_cumulative;
bool splitkv_use_fp32_as_accum; bool splitkv_use_fp32_as_accum;
bool decode_use_c_load;
// not used params // not used params
float *__restrict__ scales_q_ptr; float *__restrict__ scales_q_ptr;
......
...@@ -1841,9 +1841,9 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_ ...@@ -1841,9 +1841,9 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_
} }
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN> template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool DecodeCLoad = false>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999( __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999(
Element* k_faker, Element* v_ptr_raw,
vec4_uint k_ptr, vec4_uint k_ptr,
vec4_uint v_ptr, vec4_uint v_ptr,
Element* q_lds, Element* q_lds,
...@@ -1896,6 +1896,12 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_ ...@@ -1896,6 +1896,12 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_
index_topk[2] = index_ptr[(n_loop_real * 64) + 2 * 16 + (tid / 4)]; index_topk[2] = index_ptr[(n_loop_real * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop_real * 64) + 3 * 16 + (tid / 4)]; index_topk[3] = index_ptr[(n_loop_real * 64) + 3 * 16 + (tid / 4)];
bool invalid_topk[4];
invalid_topk[0] = index_topk[0] == -1;
invalid_topk[1] = index_topk[1] == -1;
invalid_topk[2] = index_topk[2] == -1;
invalid_topk[3] = index_topk[3] == -1;
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0]; int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0]; index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1]; index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
...@@ -1910,9 +1916,22 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_ ...@@ -1910,9 +1916,22 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_
+ (index_block * batch_stride + index_offset * seqlen_v_stride) * ELEMENT_BYTES / 4; + (index_block * batch_stride + index_offset * seqlen_v_stride) * ELEMENT_BYTES / 4;
int g_offset_s = warp_id * 32 * ELEMENT_BYTES / 4; int g_offset_s = warp_id * 32 * ELEMENT_BYTES / 4;
int g_offset_s_2 = warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4; int g_offset_s_2 = warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4;
using uint4_t = unsigned int __attribute__((ext_vector_type(4)));
const uint4_t zero4 = {0u, 0u, 0u, 0u};
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v); if constexpr (DecodeCLoad) {
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v); const int64_t token_base_elem = int64_t(index_block) * batch_stride + int64_t(index_offset) * seqlen_v_stride;
const int64_t src0_elem = token_base_elem + int64_t(g_offset_s + tid % 4 * 4) * 4 / ELEMENT_BYTES;
const int64_t src1_elem = token_base_elem + int64_t(g_offset_s_2 + tid % 4 * 4) * 4 / ELEMENT_BYTES;
const int64_t lane_store_elem = int64_t(tid) * 16 / ELEMENT_BYTES;
const int64_t dst0_elem = int64_t(lds_offset) * 4 / ELEMENT_BYTES + lane_store_elem;
const int64_t dst1_elem = int64_t(lds_offset_2) * 4 / ELEMENT_BYTES + lane_store_elem;
*reinterpret_cast<uint4_t*>(v_lds + dst0_elem) = invalid_topk[0] ? zero4 : *reinterpret_cast<const uint4_t*>(v_ptr_raw + src0_elem);
*reinterpret_cast<uint4_t*>(v_lds + dst1_elem) = invalid_topk[0] ? zero4 : *reinterpret_cast<const uint4_t*>(v_ptr_raw + src1_elem);
} else {
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
int lds_stage_id = 1; int lds_stage_id = 1;
...@@ -1928,12 +1947,28 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_ ...@@ -1928,12 +1947,28 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_
g_offset_s = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4; g_offset_s = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4;
g_offset_s_2 = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4; g_offset_s_2 = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v); if constexpr (DecodeCLoad) {
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v); const int64_t token_base_elem = int64_t(index_block) * batch_stride + int64_t(index_offset) * seqlen_v_stride;
const int64_t src0_elem = token_base_elem + int64_t(g_offset_s + tid % 4 * 4) * 4 / ELEMENT_BYTES;
const int64_t src1_elem = token_base_elem + int64_t(g_offset_s_2 + tid % 4 * 4) * 4 / ELEMENT_BYTES;
const int64_t lane_store_elem = int64_t(tid) * 16 / ELEMENT_BYTES;
const int64_t dst0_elem = int64_t(lds_offset) * 4 / ELEMENT_BYTES + lane_store_elem;
const int64_t dst1_elem = int64_t(lds_offset_2) * 4 / ELEMENT_BYTES + lane_store_elem;
const bool invalid_index = invalid_topk[total_loop / 2];
*reinterpret_cast<uint4_t*>(v_lds + dst0_elem) = invalid_index ? zero4 : *reinterpret_cast<const uint4_t*>(v_ptr_raw + src0_elem);
*reinterpret_cast<uint4_t*>(v_lds + dst1_elem) = invalid_index ? zero4 : *reinterpret_cast<const uint4_t*>(v_ptr_raw + src1_elem);
} else {
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
} }
// 不对称MLS指令 // 不对称MLS指令
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS); if constexpr (DecodeCLoad) {
flash::wait_all_warp_arrived();
} else {
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
}
lds_stage_id ^= 1; lds_stage_id ^= 1;
int stage_id = 0; int stage_id = 0;
...@@ -2237,6 +2272,398 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_ ...@@ -2237,6 +2272,398 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_
} }
// Old topk=2048 prefill fast path. Keep separate from the generic paged/fallback variant.
template<bool NeedIndexGuard, bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999_topk2048_fast(
Element* k_faker,
vec4_uint k_ptr,
vec4_uint v_ptr,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride,
int n_loop_real,
int max_seq_q_offset=0,
int max_seq_kv_offset=0) {
constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
constexpr int WARP_K = 16;
constexpr int READ_ONCE_COUNT = 16 * 32;
constexpr int kHeadDimV_OPT = 256; // lds 32x32x8x2B == 16KB
constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
static_assert (kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
static_assert (kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
static_assert (kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
static_assert (WARP_K == 16 and "Error: To simplify, only WARP_K = 32 is supported!");
static_assert (WARP_M == 16 and "Error: To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 256 and "Error: To simplify, only WARP_N = 32 is supported!");
// 计算 V lds 起始偏移量
int v_lds_base = reinterpret_cast<size_t>(v_lds);
int tid = threadIdx.x % 64;
// 准备 V 寄存器
union_vec4_f16x2<Element> v_reg[WARP_N / 32];
// MLS
vec4_uint v_srsrc;
v_srsrc[0] = v_ptr[0];
v_srsrc[1] = v_ptr[1];
v_srsrc[2] = seqlen_v_stride; // stride
v_srsrc[3] = 0;
int index_topk[4];
index_topk[0] = index_ptr[(n_loop_real * 64) + 0 * 16 + (tid / 4)];
index_topk[1] = index_ptr[(n_loop_real * 64) + 1 * 16 + (tid / 4)];
index_topk[2] = index_ptr[(n_loop_real * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop_real * 64) + 3 * 16 + (tid / 4)];
if constexpr (NeedIndexGuard) {
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];
}
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2 = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
int g_offset_v = tid % 4 * 4 + index_topk[0] * seqlen_v_stride * ELEMENT_BYTES / 4;
int g_offset_s = warp_id * 32 * ELEMENT_BYTES / 4;
int g_offset_s_2 = warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
int lds_stage_id = 1;
for(int total_loop = 1; total_loop < (kHeadDimV / kBlockN) * 4; ++total_loop) {
{
g_offset_v = tid % 4 * 4 + index_topk[total_loop / 2] * seqlen_v_stride * ELEMENT_BYTES / 4;
lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4;
g_offset_s_2 = (total_loop % 2) * WARP_N * ELEMENT_BYTES / 4 + warp_id * 32 * ELEMENT_BYTES / 4 + 128 * ELEMENT_BYTES / 4;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
// 不对称MLS指令
flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
lds_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// K DS PRE
stage_id ^= 1;
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 0;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 1;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 2;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 3;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 4;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 5;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 6;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 7;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[(total_loop - 1) / 4][((total_loop - 1) / 2 ) % 2].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
}
// 等回最后的q_panel
flash::wait_buffer_data_arrived<true>(0);
lds_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// K DS PRE
stage_id ^= 1;
{
int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 16 * 64 + 16 * 128) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 0;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 1;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 2;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 3;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 4;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 5;
int v_tile_id = stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
int abc[1];
int index_topk_qk = index_ptr[(((n_loop_real+1) % 16) * 64) + warp_id * 16];
int offset_m = index_topk_qk * seqlen_k_stride;
auto g_abc = (reinterpret_cast<uint64_t>(k_faker + offset_m));
inline_s_load_dword(abc[0], g_abc, 0);
// flash::raise_priority();
{
int min_tile_nk = 0;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 6;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
{
int min_tile_nk = 1;
#pragma unroll
for(int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
int pv_tile_id = lds_stage_id * 8 + 7;
int v_tile_id = 4 + stage_id * 2 + min_tile_nk;
pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
p_reg[1][1].f16x4,
v_reg[v_tile_id].f16x4[min_tile_n],
pv_reg[pv_tile_id][min_tile_n].f32);
}
}
// flash::lower_priority();
}
template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN> template<bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_2( __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_2(
vec4_uint q_ptr, vec4_uint q_ptr,
......
...@@ -1727,6 +1727,12 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -1727,6 +1727,12 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)]; index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)]; index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
bool invalid_topk[4];
invalid_topk[0] = index_topk[0] == -1;
invalid_topk[1] = index_topk[1] == -1;
invalid_topk[2] = index_topk[2] == -1;
invalid_topk[3] = index_topk[3] == -1;
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0]; int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0]; index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1]; index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
...@@ -2097,8 +2103,9 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2097,8 +2103,9 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
} }
} }
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA> // Old topk=2048 prefill fast path. Keep separate from the generic paged/fallback variant.
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888( template<bool NeedIndexGuard, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_777_topk2048_fast(
vec4_uint qv_ptr, vec4_uint qv_ptr,
vec4_uint q_ptr, vec4_uint q_ptr,
vec4_uint k_ptr, vec4_uint k_ptr,
...@@ -2116,6 +2123,429 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2116,6 +2123,429 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
int* index_ptr, int* index_ptr,
int batch_stride_k, int batch_stride_k,
int batch_stride_v, int batch_stride_v,
int n_loop,
int max_seq_q_offset=0,
int max_seq_k_offset=0) {
// Simplify
static_assert (kBlockK == 64 and "To simplify, only kBlockK = 32 is supported!");
static_assert (WARP_M == 16 and "To simplify, only WARP_M = 16 is supported!");
static_assert (WARP_N == 64 and "To simplify, only WARP_N = 64 is supported!");
constexpr int WARP_NUM = kBlockM / WARP_M;
constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
constexpr int ELEMENT_BYTES = sizeof(Element);
__builtin_amdgcn_sched_barrier(0);
if constexpr (kBlockN == 128) {
inline_vgpr4_init_zero_4x4x4(s_reg);
} else {
for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
for (int j = 0; j < 2; ++j) {
s_reg[i][j].u64[0] = 0.0f;
s_reg[i][j].u64[1] = 0.0f;
}
}
}
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int index_topk[4];
index_topk[0] = index_ptr[(n_loop * 64) + 0 * 16 + (tid / 4)];
index_topk[1] = index_ptr[(n_loop * 64) + 1 * 16 + (tid / 4)];
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
if constexpr (NeedIndexGuard) {
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];
}
// 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
// 计算 q_lds,k_lds 的起始偏移量
int k_lds_base = reinterpret_cast<size_t>(k_lds);
for(int i=3;i>=0;i--)
{
int k_stage_id = 0;
int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
int lds_offset_2;
int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk[i] * seqlen_k_stride * ELEMENT_BYTES / 4;
int g_offset_s = 512 * ELEMENT_BYTES / 4 + warp_id * 16;
int g_offset_s_2;
flash::wait_all_warp_arrived();
if(warp_id < 2){
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
}
if constexpr (STAGES == 2) {
k_stage_id ^= 1;
}
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
int stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[16].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[17].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
g_offset_s = 0 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = 0 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[8].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[9].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[10].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[11].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[12].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[13].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[14].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[15].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_buffer_data_arrived<true>(0);
k_stage_id ^= 1;
stage_id = 0;
// K DS
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// K DS PRE
stage_id ^= 1;
{
int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
}
// Wait DS
flash::wait_lds_data_arrived<false>(6);
// flash::raise_priority();
// MMAC
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[0].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[1].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(4);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[2].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[3].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(2);
// flash::raise_priority();
stage_id ^= 1;
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[4].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[5].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
flash::wait_lds_data_arrived<false>(0);
// flash::raise_priority();
{
int min_tile_n = 0;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[6].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
{
int min_tile_n = 1;
#pragma unroll
for(int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
int k_tile_id = 4 + stage_id * 2 + min_tile_n;
s_reg[i/2][i%2].f32 = mmac_4interleave<Element, ElementAccum>(
q_reg[7].f16x4[min_tile_k],
k_reg[k_tile_id].f16x4[min_tile_k],
s_reg[i/2][i%2].f32);
}
}
// flash::lower_priority();
}
}
template<int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA, bool DecodeCLoad = false>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888(
vec4_uint qv_ptr,
vec4_uint q_ptr,
vec4_uint k_ptr,
Element* k_ptr_raw,
Element* q_lds,
Element* k_lds,
Element* v_lds,
union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
int warp_id,
int seqlen_qv_stride,
int __seqlen_q_stride,
int seqlen_k_stride,
int seqlen_v_stride,
int* index_ptr,
int batch_stride_k,
int batch_stride_v,
int page_block_size, int page_block_size,
int n_loop, int n_loop,
int max_seq_q_offset=0, int max_seq_q_offset=0,
...@@ -2153,10 +2583,17 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2153,10 +2583,17 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)]; index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)]; index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
index_topk[0] = (index_topk[0] == -1) ? 0 : index_topk[0]; bool invalid_topk[4];
index_topk[1] = (index_topk[1] == -1) ? 0 : index_topk[1]; invalid_topk[0] = index_topk[0] == -1;
index_topk[2] = (index_topk[2] == -1) ? 0 : index_topk[2]; invalid_topk[1] = index_topk[1] == -1;
index_topk[3] = (index_topk[3] == -1) ? 0 : index_topk[3]; invalid_topk[2] = index_topk[2] == -1;
invalid_topk[3] = index_topk[3] == -1;
int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];
// 准备 q,k 寄存器 // 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2]; union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
...@@ -2166,6 +2603,8 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2166,6 +2603,8 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
int k_stage_id = 0; int k_stage_id = 0;
int stage_id; int stage_id;
using uint4_t = unsigned int __attribute__((ext_vector_type(4)));
const uint4_t zero4 = {0u, 0u, 0u, 0u};
int index_block = index_topk[3] / page_block_size; int index_block = index_topk[3] / page_block_size;
int index_offset = index_topk[3] - index_block * page_block_size; int index_offset = index_topk[3] - index_block * page_block_size;
...@@ -2176,8 +2615,19 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2176,8 +2615,19 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
int g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16; int g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
int g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64; int g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v); if constexpr (DecodeCLoad) {
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v); const int64_t token_base_elem = int64_t(index_block) * batch_stride_k + int64_t(index_offset) * seqlen_k_stride;
const int64_t src0_elem = token_base_elem + int64_t(g_offset_s + (((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16)) * 4 / ELEMENT_BYTES;
const int64_t src1_elem = token_base_elem + int64_t(g_offset_s_2 + (((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16)) * 4 / ELEMENT_BYTES;
const int64_t lane_store_elem = int64_t(tid) * 16 / ELEMENT_BYTES;
const int64_t dst0_elem = int64_t(lds_offset) * 4 / ELEMENT_BYTES + lane_store_elem;
const int64_t dst1_elem = int64_t(lds_offset_2) * 4 / ELEMENT_BYTES + lane_store_elem;
*reinterpret_cast<uint4_t*>(k_lds + dst0_elem) = invalid_topk[3] ? zero4 : *reinterpret_cast<const uint4_t*>(k_ptr_raw + src0_elem);
*reinterpret_cast<uint4_t*>(k_lds + dst1_elem) = invalid_topk[3] ? zero4 : *reinterpret_cast<const uint4_t*>(k_ptr_raw + src1_elem);
} else {
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
k_stage_id ^= 1; k_stage_id ^= 1;
...@@ -2195,10 +2645,27 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2195,10 +2645,27 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
g_offset_s = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16; g_offset_s = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16;
g_offset_s_2 = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64; g_offset_s_2 = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
flash::wait_all_warp_arrived(); flash::wait_all_warp_arrived();
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v); if constexpr (DecodeCLoad) {
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v); const int64_t token_base_elem = int64_t(index_block) * batch_stride_k + int64_t(index_offset) * seqlen_k_stride;
const int lane_dword = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16;
const int64_t src0_elem = token_base_elem + int64_t(g_offset_s + lane_dword) * 4 / ELEMENT_BYTES;
const int64_t src1_elem = token_base_elem + int64_t(g_offset_s_2 + lane_dword) * 4 / ELEMENT_BYTES;
const int64_t lane_store_elem = int64_t(tid) * 16 / ELEMENT_BYTES;
const int64_t dst0_elem = int64_t(lds_offset) * 4 / ELEMENT_BYTES + lane_store_elem;
const int64_t dst1_elem = int64_t(lds_offset_2) * 4 / ELEMENT_BYTES + lane_store_elem;
const bool invalid_index = invalid_topk[i / 2];
*reinterpret_cast<uint4_t*>(k_lds + dst0_elem) = invalid_index ? zero4 : *reinterpret_cast<const uint4_t*>(k_ptr_raw + src0_elem);
*reinterpret_cast<uint4_t*>(k_lds + dst1_elem) = invalid_index ? zero4 : *reinterpret_cast<const uint4_t*>(k_ptr_raw + src1_elem);
} else {
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
}
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS); if constexpr (DecodeCLoad) {
flash::wait_all_warp_arrived();
} else {
flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
}
k_stage_id ^= 1; k_stage_id ^= 1;
stage_id = 0; stage_id = 0;
...@@ -2533,10 +3000,11 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_ ...@@ -2533,10 +3000,11 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_
index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)]; index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)]; index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];
index_topk[0] = (index_topk[0] == -1) ? 0 : index_topk[0]; int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
index_topk[1] = (index_topk[1] == -1) ? 0 : index_topk[1]; index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
index_topk[2] = (index_topk[2] == -1) ? 0 : index_topk[2]; index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
index_topk[3] = (index_topk[3] == -1) ? 0 : index_topk[3]; index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];
// 准备 q,k 寄存器 // 准备 q,k 寄存器
union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2]; union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
......
...@@ -2513,6 +2513,332 @@ __forceinline__ __device__ void flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_n ...@@ -2513,6 +2513,332 @@ __forceinline__ __device__ void flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_n
#endif #endif
} }
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool NeedIndexGuard, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64(const Params params) {
#if defined(__gfx938__)
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int WARP_M = Kernel_traits::kWaveM;
constexpr int WARP_N = Kernel_traits::kWaveN;
constexpr int STAGES = Kernel_traits::STAGES;
constexpr int WARP_NUM = kBlockM / WARP_M;
{
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
int warp_id_vec = threadIdx.x / 64;
const int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);
const int m_block = blockIdx.x;
// Is_FlashMLA为true则无qv,为False则有qv,适配FA3和FlashMLA接口
constexpr bool Is_FlashMLA = !Is_prefix;
// 获取当前任务的长度
int sum_s_q = Is_FlashMLA ? 0: params.max_seqlen;
int actual_seqlen_q = params.seqlen_q;
int actual_seqlen_k = 2048;
// 处理边界
if (m_block * kBlockM >= actual_seqlen_q) return;
const int warp_offset_in_seq_q = m_block * kBlockM + warp_id * WARP_M;
const int warp_seqq_limit = Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM;
// 分配 lds Q/P same place, K/V same place;
// extern __shared__ Element smem[];
// int* index_lds = (int *)&(smem);
// Element* q_lds = (Element*)(index_lds + 1024 + 256); // 16KB
// Element* k_lds = q_lds + ((8*1024) / sizeof(Element)); // 16KB
// Element* v_lds = q_lds; // 32KB
extern __shared__ Element smem[];
Element* q_lds = (Element*)&(smem); // 16KB
Element* k_lds = q_lds; // 16KB
Element* v_lds = q_lds;
int* index_lds = (int *)(q_lds + 8 * 1024);
// int* sIndices = (int *)(q_lds + 8192);
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
const int linear_q_start = m_block * kBlockM;
const int query_idx = linear_q_start / params.ngroups;
const int q_head_start = linear_q_start - query_idx * params.ngroups;
constexpr int real_topk = 2048;
const int n_block_min = 0;
int n_block_max = real_topk / kBlockN;
// 计算数据跨度
int seqlen_q_stride = params.q_head_stride;
int seqlen_k_stride = params.k_row_stride;
int seqlen_v_stride = params.v_row_stride;
int seqlen_o_stride = params.o_head_stride;
int seqlen_qv_stride = params.qv_row_stride; // 当走FlashMLA接口时,不会使用
int row_offset_q, row_offset_k, row_offset_v, row_offset_o, row_offset_lse, row_offset_qv;
int headdim_split_id = 0;
// const int page_block_size = params.page_block_size;
// int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int *index_ptr = params.sparse_indices
+ bidb * params.sparse_indices_batch_stride
+ query_idx * params.sparse_indices_row_stride
+ (q_head_start / params.ngroups) * params.sparse_indices_head_stride;
// const int block_table_idx = 0;
// const int block_table_offset = 0;
row_offset_q = bidb * int64_t(params.q_batch_stride) + query_idx * int64_t(params.q_token_stride) + q_head_start * int64_t(params.q_head_stride);
row_offset_k = 0;
row_offset_v = 0;
row_offset_o = bidb * int64_t(params.o_batch_stride) + query_idx * int64_t(params.o_row_stride) + q_head_start * int64_t(params.o_head_stride);
row_offset_lse = bidb * params.seqlen_q + m_block * kBlockM;
// row_offset_k = bidb * int64_t(params.k_batch_stride) + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
// row_offset_v = bidb * int64_t(params.v_batch_stride) + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
// 根据起始数据偏移量准备 Q/K/V 的 buffer resource 寄存器
// q_ptr : 64 | qv_ptr : 512 | k_ptr : 64(k_rope) | v_ptr : 512(k_nope)
vec4_uint qv_ptr;
auto q_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q);
if constexpr (!Is_FlashMLA)
qv_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.qv_ptr) + row_offset_qv);
auto k_ptr = prepare_for_matrix_load<kHeadDim>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto k_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k);
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr));
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
// apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr int n_masking_steps = 1;
int tid = threadIdx.x % 64;
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid * 4;
int g_offset_s = warp_id * 256;
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
if constexpr (NeedIndexGuard) {
flash::wait_buffer_data_arrived<true>(0);
int index_offset = warp_id * 256 + tid * 4;
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = index_offset + i;
int index_value = index_lds[local_index];
index_lds[local_index] = (index_value < 0 || index_value >= params.seqlen_k) ? -1 : index_value;
}
flash::wait_all_warp_arrived();
}
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr bool PREFETCH_K = (Is_even_MN) and( ( kHeadDim == 576 and kHeadDimV == 512 )); // 简单场景下开启
constexpr bool ALLOW_PREFETCH = (STAGES > 1); // 客观上决定是否开启 prefetch
if constexpr (PREFETCH_K and ALLOW_PREFETCH) {
if (n_block_min < n_block_max - n_masking_steps) {
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, Is_even_MN ? 0: actual_seqlen_q - m_block * kBlockM);
prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, Is_even_MN ? 0: actual_seqlen_k - n_block_min * kBlockN);
}
}
vec_Accum<ElementAccum> scores_max[WARP_M / 16]; // 只处理16行,所以只有1个reg
vec_Accum<ElementAccum> scores_sum[WARP_M / 16];
vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / 16) * (kBlockK / 32)][2];
// 内联失败,手动展开
{
constexpr int K_LOOP_COUNT = kHeadDimV / kBlockK;
constexpr int M_WARP_COUNT = WARP_M / 16;
constexpr int K_WARP_COUNT = kBlockK / 32;
constexpr int M_MMAC_COUNT = 1;
#pragma unroll
for (int i = 0; i < M_WARP_COUNT; ++i) {
scores_max[i].f32[0] = -INFINITY;
scores_sum[i].f32[0] = 0;
}
uint64_t pk_zero = 0;
#pragma unroll
for (int i = 0; i < K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT; ++i) {
#pragma unroll
for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
// #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
// acc_o[i][min_tile_n].u64[0] = __builtin_hcu_mov_b64(pk_zero);
// acc_o[i][min_tile_n].u64[1] = __builtin_hcu_mov_b64(pk_zero);
// #else
acc_o[i][min_tile_n].f32[0] = 0;
acc_o[i][min_tile_n].f32[1] = 0;
acc_o[i][min_tile_n].f32[2] = 0;
acc_o[i][min_tile_n].f32[3] = 0;
// #endif
}
}
}
}
union_vec4_f16x2<Element> q_reg[(WARP_M * 64) / (16 * 64) * (kHeadDim / kBlockK)];
auto QK_GEMM_FUNC_new = &qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_q<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>;
prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, 64, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, warp_seqq_limit);
QK_GEMM_FUNC_new(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_ptr_buffer /* v512 */, q_lds, k_lds, v_lds, q_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_ptr, params.k_batch_stride, params.v_batch_stride, warp_seqq_limit);
auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_777_topk2048_fast<NeedIndexGuard, kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>;
auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_576_512<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999_topk2048_fast<NeedIndexGuard, false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN>;
// Mainloop, 主循环, 不做 causal mask 的部分
for (int n_block_loop = n_block_min; n_block_loop < 16; ++n_block_loop) {
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int warp_offset_in_seqkv = n_block_loop * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0: actual_seqlen_k - warp_offset_in_seqkv;
// 预取 K 的数据到 lds // 预取 k_ptr(k_rope), q_ptr 到 k_lds, q_lds
// if constexpr (not PREFETCH_K) {
// // prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, warp_seqq_limit);
// prefetch_k_to_lds_mls_ds_576_512_buffer_load_nopage<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr_buffer, k_lds, warp_id, seqlen_k_stride, index_lds, params.k_batch_stride, n_block_loop, warp_seqkv_limit);
// }
// 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2];
// QK gemm
if constexpr (Is_FlashMLA) {
QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.k_batch_stride, params.v_batch_stride, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
} else {
// QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit);
}
// if constexpr (!Is_causal) {
// if constexpr (!Is_even_MN) { prefill_mla_apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, warp_seqkv_limit); }
// } else {
// if constexpr (Is_MTP) {
// if constexpr (Is_FlashMLA) {
// flashmla_apply_mtp_mask_causal_gfx938<vec4_Accum<ElementAccum>, kBlockN, WARP_M, WARP_NUM>(s_reg, warp_offset_in_seqkv/* n_block_loop * kBlockN */, actual_seqlen_k, warp_offset_in_seq_q/* m_block * kBlockM + warp_id_row * 16 */, actual_seqlen_q, params.ngroups, params.mtp);
// } else {
// prefill_mla_apply_mtp_mask_causal_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, warp_offset_in_seqkv, actual_seqlen_k, warp_offset_in_seq_q, actual_seqlen_q);
// }
// } else {
// prefill_mla_apply_mask_causal_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, warp_offset_in_seqkv, actual_seqlen_k, warp_offset_in_seq_q, actual_seqlen_q);
// }
// }
if constexpr (NeedIndexGuard) {
prefill_dsa_apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, index_lds, warp_offset_in_seqkv, real_topk);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
prefill_mla_softmax_rescale_o<false, Is_causal, vec4_Accum<ElementAccum>, vec_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// softmax(QK) f32 -> f16
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockN / 32)][2];
prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg);
// PV gemm
PV_GEMM_FUNC_IN_MASK(k_faker, k_ptr, v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.v_batch_stride, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
// int abc[1];
// int index_topk = index_ptr[(n_block_loop * 64) + warp_id * 16];
// int offset_m = index_topk * seqlen_k_stride;
// auto g_abc = (reinterpret_cast<uint64_t>(k_faker + offset_m));
// inline_s_load_dword(abc[0], g_abc, 0);
}
lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
g_offset_v = tid * 4;
g_offset_s = 1024 + warp_id * 256;
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
flash::wait_buffer_data_arrived<true>(0);
if constexpr (NeedIndexGuard) {
int index_offset = warp_id * 256 + tid * 4;
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = index_offset + i;
int index_value = index_lds[local_index];
index_lds[local_index] = (index_value < 0 || index_value >= params.seqlen_k) ? -1 : index_value;
}
flash::wait_all_warp_arrived();
}
for (int n_block_loop = 16; n_block_loop < n_block_max; ++n_block_loop) {
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int warp_offset_in_seqkv = n_block_loop * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0: actual_seqlen_k - warp_offset_in_seqkv;
int n_loop_real = n_block_loop - 16;
// 预取 K 的数据到 lds // 预取 k_ptr(k_rope), q_ptr 到 k_lds, q_lds
// if constexpr (not PREFETCH_K) {
// // prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, warp_seqq_limit);
// prefetch_k_to_lds_mls_ds_576_512_buffer_load_nopage<kHeadDim, kBlockN, kBlockK, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr_buffer, k_lds, warp_id, seqlen_k_stride, index_lds, params.k_batch_stride, n_loop_real, warp_seqkv_limit);
// }
// 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2];
// QK gemm
if constexpr (Is_FlashMLA) {
QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.k_batch_stride, params.v_batch_stride, n_loop_real, warp_seqq_limit, warp_seqkv_limit);
} else {
// QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit);
}
// if constexpr (!Is_causal) {
// if constexpr (!Is_even_MN) { prefill_mla_apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, warp_seqkv_limit); }
// } else {
// if constexpr (Is_MTP) {
// if constexpr (Is_FlashMLA) {
// flashmla_apply_mtp_mask_causal_gfx938<vec4_Accum<ElementAccum>, kBlockN, WARP_M, WARP_NUM>(s_reg, warp_offset_in_seqkv/* n_block_loop * kBlockN */, actual_seqlen_k, warp_offset_in_seq_q/* m_block * kBlockM + warp_id_row * 16 */, actual_seqlen_q, params.ngroups, params.mtp);
// } else {
// prefill_mla_apply_mtp_mask_causal_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, warp_offset_in_seqkv, actual_seqlen_k, warp_offset_in_seq_q, actual_seqlen_q);
// }
// } else {
// prefill_mla_apply_mask_causal_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, warp_offset_in_seqkv, actual_seqlen_k, warp_offset_in_seq_q, actual_seqlen_q);
// }
// }
if constexpr (NeedIndexGuard) {
prefill_dsa_apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, index_lds, warp_offset_in_seqkv, real_topk);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
prefill_mla_softmax_rescale_o<false, Is_causal, vec4_Accum<ElementAccum>, vec_Accum<ElementAccum>, kHeadDimV, kBlockK, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// softmax(QK) f32 -> f16
union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockN / 32)][2];
prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg);
// PV gemm
PV_GEMM_FUNC_IN_MASK(k_faker, k_ptr, v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.v_batch_stride, n_loop_real, warp_seqq_limit, warp_seqkv_limit);
// int abc[1];
// int32_t offset_m = (((n_block_loop + 1) % 32) * 64) + warp_id * 16 + (tid / 4);
// auto g_abc = __builtin_amdgcn_readfirstlane(reinterpret_cast<uint64_t>(index_ptr + offset_m));
// inline_s_load_dword(abc[0], g_abc, 0);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec_Accum<ElementAccum> lse[WARP_M / 16];
prefill_mla_epilugue_rescale_acco<WARP_M, kBlockK, kHeadDimV, Is_dropout && Is_training, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(acc_o, lse, scores_max, scores_sum, params.scale_softmax, 0/* params.rp_dropout */);
/**************************************************************************************************************************************/
constexpr bool Is_Interleave = true;
int lane_id = threadIdx.x & 63;
if (params.softmax_lse_ptr != nullptr) {
prefill_mla_epilogue_store_lse<WARP_M, Is_even_MN, false/*SplitD*/, Is_Interleave, ElementAccum, vec_Accum<ElementAccum>, 1/* M_MMAC_COUNT */>(lse, params.softmax_lse_ptr, row_offset_lse, warp_id, lane_id, 0, actual_seqlen_q - m_block * kBlockM);
}
/**************************************************************************************************************************************/
Element* o_ptr = reinterpret_cast<Element *>(params.o_ptr) + row_offset_o;
prefill_mla_epilogue_store_output<kHeadDimV, kBlockM, kBlockK, WARP_M, Is_even_MN, Is_Interleave, false/*TcpSwizzle*/, Element, ElementAccum, 1/* M_MMAC_COUNT */>(o_ptr, acc_o, m_block, warp_id, lane_id, seqlen_o_stride, actual_seqlen_q);
}
#endif
}
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, typename Params> template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64(const Params params) { __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64(const Params params) {
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, true, Params>(params); flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, true, Params>(params);
...@@ -2523,7 +2849,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2523,7 +2849,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, false, Params>(params); flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_impl<Kernel_traits, Is_training, Is_dropout, Is_prefix, Is_causal, Is_even_MN, Is_even_K, Return_softmax, Is_MTP, Layout, false, Params>(params);
} }
template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool Has_extra, typename Params> template<typename Kernel_traits, bool Is_training, bool Is_dropout, bool Is_prefix, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_MTP, int Layout, bool DecodeCLoad, bool Has_extra, typename Params>
__global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64(const Params params) { __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64(const Params params) {
#if defined(__gfx938__) #if defined(__gfx938__)
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
...@@ -2645,6 +2971,8 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2645,6 +2971,8 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
} }
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k; auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v; auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
auto extra_k_faker = reinterpret_cast<Element*>(Has_extra ? params.extra_k_ptr : params.k_ptr);
auto extra_v_faker = reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr);
// apply causal mask 的步骤和 no causal mask 的步骤分开算 // apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr int n_masking_steps = 1; constexpr int n_masking_steps = 1;
...@@ -2728,9 +3056,9 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2728,9 +3056,9 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
QK_GEMM_FUNC_new(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_ptr_buffer /* v512 */, q_lds, k_lds, v_lds, q_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_ptr, params.k_batch_stride, params.v_batch_stride, warp_seqq_limit); QK_GEMM_FUNC_new(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_ptr_buffer /* v512 */, q_lds, k_lds, v_lds, q_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_ptr, params.k_batch_stride, params.v_batch_stride, warp_seqq_limit);
auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>; auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA, DecodeCLoad>;
auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_576_512<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>; auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_576_512<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN>; auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN, DecodeCLoad>;
// Mainloop, 主循环, 不做 causal mask 的部分 // Mainloop, 主循环, 不做 causal mask 的部分
for (int n_block_loop = 0; n_block_loop < main_num_blocks; ++n_block_loop) { for (int n_block_loop = 0; n_block_loop < main_num_blocks; ++n_block_loop) {
...@@ -2754,7 +3082,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2754,7 +3082,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_999<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>( qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_999<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>(
q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.k_batch_stride, params.v_batch_stride, params.page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit); q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.k_batch_stride, params.v_batch_stride, params.page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
} else { } else {
QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.k_batch_stride, params.v_batch_stride, params.page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit); QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, k_faker, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.k_batch_stride, params.v_batch_stride, params.page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
} }
} else { } else {
// QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit); // QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit);
...@@ -2782,7 +3110,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2782,7 +3110,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg); prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg);
// PV gemm // PV gemm
PV_GEMM_FUNC_IN_MASK(k_faker, k_ptr, v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.v_batch_stride, params.page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit); PV_GEMM_FUNC_IN_MASK(v_faker, k_ptr, v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_lds, params.v_batch_stride, params.page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
} }
...@@ -2804,7 +3132,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2804,7 +3132,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
// QK gemm // QK gemm
if constexpr (Is_FlashMLA) { if constexpr (Is_FlashMLA) {
QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, extra_k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, params.extra_k_row_stride, params.extra_v_row_stride, extra_index_lds, params.extra_k_batch_stride, params.extra_v_batch_stride, params.extra_page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit); QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, extra_k_ptr_buffer /* k576 */, extra_k_faker, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, params.extra_k_row_stride, params.extra_v_row_stride, extra_index_lds, params.extra_k_batch_stride, params.extra_v_batch_stride, params.extra_page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
} else { } else {
// QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit); // QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit);
} }
...@@ -2831,7 +3159,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2831,7 +3159,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg); prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg);
// PV gemm // PV gemm
PV_GEMM_FUNC_IN_MASK(k_faker, k_ptr, extra_v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, params.extra_k_row_stride, params.extra_v_row_stride, extra_index_lds, params.extra_v_batch_stride, params.extra_page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit); PV_GEMM_FUNC_IN_MASK(extra_v_faker, k_ptr, extra_v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, params.extra_k_row_stride, params.extra_v_row_stride, extra_index_lds, params.extra_v_batch_stride, params.extra_page_block_size, n_block_loop, warp_seqq_limit, warp_seqkv_limit);
} }
...@@ -2903,7 +3231,8 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2903,7 +3231,8 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
Element* k_lds = q_lds; // 16KB Element* k_lds = q_lds; // 16KB
Element* v_lds = q_lds; Element* v_lds = q_lds;
int* index_lds = (int *)(q_lds + 8 * 1024); int* index_lds = (int *)(q_lds + 8 * 1024);
int* extra_index_lds = (int *)(index_lds + 256); constexpr int kIndexChunk = 1024;
constexpr int kIndexChunkBlocks = kIndexChunk / kBlockN;
// int* sIndices = (int *)(q_lds + 8192); // int* sIndices = (int *)(q_lds + 8192);
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置 // 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
int split_id = blockIdx.y; int split_id = blockIdx.y;
...@@ -2924,7 +3253,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2924,7 +3253,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
int n_block_min = split_id * blocks_per_split; int n_block_min = split_id * blocks_per_split;
int n_block_max = min(n_block_min + blocks_per_split, total_num_blocks); int n_block_max = min(n_block_min + blocks_per_split, total_num_blocks);
if (n_block_max <= n_block_min) return; // Empty split still needs to write neutral accum/lse for the reduce kernel.
// 计算数据跨度 // 计算数据跨度
int seqlen_q_stride = params.q_head_stride; int seqlen_q_stride = params.q_head_stride;
int seqlen_k_stride = params.k_row_stride; int seqlen_k_stride = params.k_row_stride;
...@@ -2977,49 +3306,12 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -2977,49 +3306,12 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v); auto v_ptr = prepare_for_matrix_load<kHeadDimV>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v); auto v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v);
auto extra_v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr)); auto extra_v_ptr_buffer = tcp_cache_swizzle_func<64, Element>(reinterpret_cast<Element*>(Has_extra ? params.extra_v_ptr : params.v_ptr));
auto index_ptr_buffer = tcp_cache_swizzle_func<0, int>(reinterpret_cast<int*>(index_ptr));
int tid = threadIdx.x % 64; int tid = threadIdx.x % 64;
if constexpr (Has_extra) {
auto extra_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(extra_index_ptr);
int lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
int g_offset_v = tid * 4;
int g_offset_s = warp_id * 256;
inline_buffer_load_dwordx4_lds(extra_index_lds, extra_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k; auto k_faker = reinterpret_cast<Element*>(params.k_ptr) + row_offset_k;
auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v; auto v_faker = reinterpret_cast<Element*>(params.v_ptr) + row_offset_v;
// apply causal mask 的步骤和 no causal mask 的步骤分开算 // apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr int n_masking_steps = 1; constexpr int n_masking_steps = 1;
int lds_offset = __builtin_amdgcn_readfirstlane(Has_extra ? warp_id * 64 : warp_id * 4 * 64);
int g_offset_v = Has_extra ? tid : tid * 4;
int g_offset_s = Has_extra ? warp_id * 64 : warp_id * 256;
if constexpr (Has_extra) {
inline_buffer_load_dword_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
} else {
inline_buffer_load_dwordx4_lds(index_lds, index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
flash::wait_buffer_data_arrived<true>(0);
flash::wait_all_warp_arrived();
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = Has_extra ? warp_id * 64 + tid : warp_id * 256 + tid * 4 + i;
if (local_index >= main_topk_length) {
index_lds[local_index] = main_topk_length > 0 ? index_lds[main_topk_length - 1] : -1;
}
}
if constexpr (Has_extra) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
if (local_index >= extra_topk_length) {
extra_index_lds[local_index] = extra_topk_length > 0 ? extra_index_lds[extra_topk_length - 1] : -1;
}
}
}
flash::wait_all_warp_arrived();
// 是否做 prefetch K, PV 结束后, prefetch K 有风险 // 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr bool PREFETCH_K = (Is_even_MN) and( ( kHeadDim == 512 and kHeadDimV == 512 )); // 简单场景下开启 constexpr bool PREFETCH_K = (Is_even_MN) and( ( kHeadDim == 512 and kHeadDimV == 512 )); // 简单场景下开启
constexpr bool ALLOW_PREFETCH = (STAGES > 1); // 客观上决定是否开启 prefetch constexpr bool ALLOW_PREFETCH = (STAGES > 1); // 客观上决定是否开启 prefetch
...@@ -3070,24 +3362,61 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -3070,24 +3362,61 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
QK_GEMM_FUNC_new(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_ptr_buffer /* v512 */, q_lds, k_lds, v_lds, q_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_ptr, params.k_batch_stride, params.v_batch_stride, warp_seqq_limit); QK_GEMM_FUNC_new(q_ptr /* q576 */, q_ptr /* q576 */, k_ptr_buffer /* k576 */, v_ptr_buffer /* v512 */, q_lds, k_lds, v_lds, q_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, index_ptr, params.k_batch_stride, params.v_batch_stride, warp_seqq_limit);
auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>; auto QK_GEMM_FUNC = &qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA, true>;
auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_576_512<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>; auto PV_GEMM_FUNC = &pv_gemm_prefetch_k_mls_ds_576_512<PREFETCH_K, kHeadDim, kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, STAGES, Element, ElementAccum, Is_even_MN>;
auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN>; auto PV_GEMM_FUNC_IN_MASK = &pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999<false, kHeadDim, kHeadDimV, kBlockM, 256, kBlockN, WARP_M, 256, STAGES, Element, ElementAccum, Is_even_MN, true>;
for (int logical_block = n_block_min; logical_block < n_block_max; ++logical_block) { int logical_chunk_start = n_block_min;
bool is_extra = Has_extra && logical_block >= main_num_blocks; while (logical_chunk_start < n_block_max) {
int rel_block = is_extra ? logical_block - main_num_blocks : logical_block; const bool chunk_is_extra = Has_extra && logical_chunk_start >= main_num_blocks;
int cur_topk_length = is_extra ? extra_topk_length : main_topk_length; const int chunk_base = chunk_is_extra ? main_num_blocks : 0;
int* cur_index_lds = is_extra ? extra_index_lds : index_lds; const int chunk_rel_block_start = ((logical_chunk_start - chunk_base) / kIndexChunkBlocks) * kIndexChunkBlocks;
auto cur_k_ptr_buffer = is_extra ? extra_k_ptr_buffer : k_ptr_buffer; const int chunk_topk_length = chunk_is_extra ? extra_topk_length : main_topk_length;
auto cur_v_ptr_buffer = is_extra ? extra_v_ptr_buffer : v_ptr_buffer; const int chunk_index_width = chunk_is_extra ? params.extra_topk : params.topk;
int cur_seqlen_k_stride = is_extra ? params.extra_k_row_stride : seqlen_k_stride; int* chunk_index_ptr = chunk_is_extra ? extra_index_ptr : index_ptr;
int cur_seqlen_v_stride = is_extra ? params.extra_v_row_stride : seqlen_v_stride; auto chunk_index_ptr_buffer = tcp_cache_swizzle_func<0, int>(chunk_index_ptr + chunk_rel_block_start * kBlockN);
int cur_k_batch_stride = is_extra ? params.extra_k_batch_stride : params.k_batch_stride;
int cur_v_batch_stride = is_extra ? params.extra_v_batch_stride : params.v_batch_stride; #pragma unroll
int cur_page_block_size = is_extra ? params.extra_page_block_size : params.page_block_size; for (int index_load_iter = 0; index_load_iter < 4; ++index_load_iter) {
int warp_offset_in_seqkv = rel_block * kBlockN; const int local_index_base = index_load_iter * 256 + warp_id * 64;
int warp_seqkv_limit = Is_even_MN ? 0 : cur_topk_length - warp_offset_in_seqkv; const int global_index_base = chunk_rel_block_start * kBlockN + local_index_base;
if (global_index_base < chunk_index_width) {
int lds_offset = __builtin_amdgcn_readfirstlane(local_index_base);
int g_offset_v = tid;
int g_offset_s = local_index_base;
inline_buffer_load_dword_lds(index_lds, chunk_index_ptr_buffer, lds_offset, g_offset_s, g_offset_v);
}
}
flash::wait_buffer_data_arrived<true>(0);
flash::wait_all_warp_arrived();
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int local_index = warp_id * 256 + tid * 4 + i;
const int global_index = chunk_rel_block_start * kBlockN + local_index;
if (global_index >= chunk_topk_length) {
index_lds[local_index] = -1;
}
}
const int logical_chunk_aligned_end = chunk_base + chunk_rel_block_start + kIndexChunkBlocks;
const int logical_side_end = chunk_is_extra ? total_num_blocks : main_num_blocks;
const int chunk_block_end = min(min(n_block_max, logical_chunk_aligned_end), logical_side_end);
for (int logical_block = logical_chunk_start; logical_block < chunk_block_end; ++logical_block) {
int rel_block = chunk_is_extra ? logical_block - main_num_blocks : logical_block;
int rel_block_in_chunk = rel_block - chunk_rel_block_start;
int cur_topk_length = chunk_is_extra ? extra_topk_length : main_topk_length;
auto cur_k_ptr_buffer = chunk_is_extra ? extra_k_ptr_buffer : k_ptr_buffer;
auto cur_v_ptr_buffer = chunk_is_extra ? extra_v_ptr_buffer : v_ptr_buffer;
int cur_seqlen_k_stride = chunk_is_extra ? params.extra_k_row_stride : seqlen_k_stride;
int cur_seqlen_v_stride = chunk_is_extra ? params.extra_v_row_stride : seqlen_v_stride;
int cur_k_batch_stride = chunk_is_extra ? params.extra_k_batch_stride : params.k_batch_stride;
int cur_v_batch_stride = chunk_is_extra ? params.extra_v_batch_stride : params.v_batch_stride;
int cur_page_block_size = chunk_is_extra ? params.extra_page_block_size : params.page_block_size;
int warp_offset_in_seqkv = rel_block * kBlockN;
int warp_offset_in_chunk = rel_block_in_chunk * kBlockN;
int chunk_real_topk = cur_topk_length - chunk_rel_block_start * kBlockN;
int warp_seqkv_limit = Is_even_MN ? 0 : cur_topk_length - warp_offset_in_seqkv;
// 准备 QK gemm 输出的寄存器 // 准备 QK gemm 输出的寄存器
vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2]; vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2];
...@@ -3096,16 +3425,17 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -3096,16 +3425,17 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
if constexpr (Is_FlashMLA) { if constexpr (Is_FlashMLA) {
if constexpr (kHeadDim == 576) { if constexpr (kHeadDim == 576) {
qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_999<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>( qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_999<kHeadDim, kHeadDimV, kBlockM, kBlockN, 64, WARP_M, WARP_N, STAGES, Element, ElementAccum, Is_even_MN, Is_FlashMLA>(
q_ptr /* q576 */, q_ptr /* q576 */, cur_k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, cur_seqlen_k_stride, cur_seqlen_v_stride, cur_index_lds, cur_k_batch_stride, cur_v_batch_stride, cur_page_block_size, rel_block, warp_seqq_limit, warp_seqkv_limit); q_ptr /* q576 */, q_ptr /* q576 */, cur_k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, cur_seqlen_k_stride, cur_seqlen_v_stride, index_lds, cur_k_batch_stride, cur_v_batch_stride, cur_page_block_size, rel_block_in_chunk, warp_seqq_limit, warp_seqkv_limit);
} else { } else {
QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, cur_k_ptr_buffer /* k576 */, v_faker /* v512 */, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, cur_seqlen_k_stride, cur_seqlen_v_stride, cur_index_lds, cur_k_batch_stride, cur_v_batch_stride, cur_page_block_size, rel_block, warp_seqq_limit, warp_seqkv_limit); Element* cur_k_ptr_raw = reinterpret_cast<Element*>(chunk_is_extra ? params.extra_k_ptr : params.k_ptr);
QK_GEMM_FUNC(q_ptr /* q576 */, q_ptr /* q576 */, cur_k_ptr_buffer /* k576 */, cur_k_ptr_raw, q_lds, k_lds, v_lds, q_reg, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, cur_seqlen_k_stride, cur_seqlen_v_stride, index_lds, cur_k_batch_stride, cur_v_batch_stride, cur_page_block_size, rel_block_in_chunk, warp_seqq_limit, warp_seqkv_limit);
} }
} else { } else {
// QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit); // QK_GEMM_FUNC(qv_ptr /* qv512 */, q_ptr /* q64 */, v_ptr /* k512 */, v_ptr /* kv512 */, q_lds, k_lds, v_lds, s_reg, warp_id, seqlen_qv_stride, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride, warp_seqq_limit, warp_seqkv_limit);
} }
if constexpr (!Is_causal) { if constexpr (!Is_causal) {
decode_dsa_apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, cur_index_lds, warp_offset_in_seqkv,cur_topk_length); decode_dsa_apply_mask_gfx938<vec4_Accum<ElementAccum>, WARP_M, kBlockN, 1/* M_MMAC_COUNT */>(s_reg, index_lds, warp_offset_in_chunk, cur_topk_length);
} else { } else {
if constexpr (Is_MTP) { if constexpr (Is_MTP) {
if constexpr (Is_FlashMLA) { if constexpr (Is_FlashMLA) {
...@@ -3126,8 +3456,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa ...@@ -3126,8 +3456,11 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_mla_decode_kernel_gfx938_dsa
prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg); prefill_mla_convert_pk_type<WARP_M, kBlockN, Element, ElementAccum, 1/* M_MMAC_COUNT */>(p_reg, s_reg);
// PV gemm // PV gemm
PV_GEMM_FUNC_IN_MASK(k_faker, k_ptr, cur_v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, cur_seqlen_k_stride, cur_seqlen_v_stride, cur_index_lds, cur_v_batch_stride, cur_page_block_size, rel_block, warp_seqq_limit, warp_seqkv_limit); Element* cur_v_ptr_raw = reinterpret_cast<Element*>(chunk_is_extra ? params.extra_v_ptr : params.v_ptr);
PV_GEMM_FUNC_IN_MASK(cur_v_ptr_raw, k_ptr, cur_v_ptr_buffer, q_lds, k_lds, v_lds, p_reg, acc_o, warp_id, seqlen_q_stride, cur_seqlen_k_stride, cur_seqlen_v_stride, index_lds, cur_v_batch_stride, cur_page_block_size, rel_block_in_chunk, warp_seqq_limit, warp_seqkv_limit);
}
logical_chunk_start = chunk_block_end;
} }
// lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64); // lds_offset = __builtin_amdgcn_readfirstlane(warp_id * 4 * 64);
......
...@@ -562,7 +562,7 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params ...@@ -562,7 +562,7 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa &params
BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] { BOOL_SWITCH(params.mtp > 1/* is_mtp */, Is_MTP, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(has_extra, Has_extra, [&] { BOOL_SWITCH(has_extra, Has_extra, [&] {
flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, Has_extra, Flash_fwd_mla_params_dsa> flash::flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<Kernel_traits, true/*Is_training*/, Is_dropout, false/* Is_prefix | flashmla */,Is_causal, IsEvenMNConst, /*Is_even_K*/true, /*Return_softmax*/false, Is_MTP, 0, true, Has_extra, Flash_fwd_mla_params_dsa>
<<<dimGrid, dimBlock, 21 * 1024, stream>>>(params); <<<dimGrid, dimBlock, 21 * 1024, stream>>>(params);
}); });
}); });
......
...@@ -55,7 +55,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -55,7 +55,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
} }
// max-rescale coefficient x sum-rescale coefficient // max-rescale coefficient x sum-rescale coefficient
lds[tx] = s_sum_load_ori * s_max_ratio / s_sum_tmp; lds[tx] = s_sum_tmp > 0.f ? s_sum_load_ori * s_max_ratio / s_sum_tmp : 0.f;
// finally, do rescale for each split and reduce the sum of them // finally, do rescale for each split and reduce the sum of them
// each block(1waves) process (num_splits x head_dim) elements in total // each block(1waves) process (num_splits x head_dim) elements in total
...@@ -76,7 +76,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel( ...@@ -76,7 +76,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
for (int i = 0; i < num_splits; ++i) { for (int i = 0; i < num_splits; ++i) {
// read ultimate scale value for current split // read ultimate scale value for current split
float s_scale = lds[i]; float s_scale = lds[i];
bool within_splits = (i < true_num_splits); bool within_splits = (i < true_num_splits) && (s_scale > 0.f);
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
// read ultimate scale value for current split // read ultimate scale value for current split
...@@ -197,7 +197,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -197,7 +197,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
s_sum_tmp = lds[LDS_ACCUM]; s_sum_tmp = lds[LDS_ACCUM];
// max-rescale coefficient x sum-rescale coefficient // max-rescale coefficient x sum-rescale coefficient
lds[tx] = s_sum_load_ori * s_max_ratio / s_sum_tmp; lds[tx] = s_sum_tmp > 0.f ? s_sum_load_ori * s_max_ratio / s_sum_tmp : 0.f;
// finally, do rescale for each split and reduce the sum of them // finally, do rescale for each split and reduce the sum of them
// each block(multiple waves) process (num_splits x head_dim) elements in total // each block(multiple waves) process (num_splits x head_dim) elements in total
...@@ -224,7 +224,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) { ...@@ -224,7 +224,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
for (int i = 0; i < split_count_this_wave; ++i) { for (int i = 0; i < split_count_this_wave; ++i) {
// read ultimate scale value for current split // read ultimate scale value for current split
float s_scale = lds[begin + i]; float s_scale = lds[begin + i];
bool within_splits = (begin + i) < true_num_splits; bool within_splits = ((begin + i) < true_num_splits) && (s_scale > 0.f);
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
...@@ -402,13 +402,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -402,13 +402,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step)); lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
} }
bool has_valid_lse = (lse_max_local != -INFINITY);
// reduce sum lse // reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local); float lse_local_logsum = has_valid_lse ? __expf(lse_local - lse_max_local) : 0.f;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step); lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
} }
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local; lse_local_logsum = has_valid_lse ? __logf(lse_local_logsum) + lse_max_local : -INFINITY;
// store softmax_lse // store softmax_lse
if (tx == 0) { if (tx == 0) {
...@@ -416,14 +417,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel ...@@ -416,14 +417,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
} }
// store rescale coefficient into lds // store rescale coefficient into lds
lds[tx] = __expf(lse_local - lse_local_logsum); lds[tx] = has_valid_lse ? __expf(lse_local - lse_local_logsum) : 0.f;
// num_splits may not be 64, and thus need boundary judgement // num_splits may not be 64, and thus need boundary judgement
#pragma unroll #pragma unroll
for (int i = 0; i < num_splits; ++i) { for (int i = 0; i < num_splits; ++i) {
// read ultimate scale value for current split // read ultimate scale value for current split
float s_scale = lds[i]; float s_scale = lds[i];
bool within_splits = (i < true_num_splits); bool within_splits = (i < true_num_splits) && (s_scale > 0.f);
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
if constexpr (kHeadDim % 128 == 0) { if constexpr (kHeadDim % 128 == 0) {
...@@ -555,13 +556,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -555,13 +556,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step)); lse_max_local = max(lse_max_local, flash::__shfl_xor_tmp(lse_max_local, step));
} }
bool has_valid_lse = (lse_max_local != -INFINITY);
// reduce sum lse // reduce sum lse
float lse_local_logsum = __expf(lse_local - lse_max_local); float lse_local_logsum = has_valid_lse ? __expf(lse_local - lse_max_local) : 0.f;
#pragma unroll #pragma unroll
for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) { for (int step = SPLIT_COUNT >> 1; step > 0; step = (step >> 1)) {
lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step); lse_local_logsum = lse_local_logsum + flash::__shfl_xor_tmp(lse_local_logsum, step);
} }
lse_local_logsum = __logf(lse_local_logsum) + lse_max_local; lse_local_logsum = has_valid_lse ? __logf(lse_local_logsum) + lse_max_local : -INFINITY;
float attn_sink_o_scale = 1.0f; float attn_sink_o_scale = 1.0f;
if (params.attn_sink != nullptr) { if (params.attn_sink != nullptr) {
...@@ -578,7 +580,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -578,7 +580,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
} }
// store rescale coefficient into lds // store rescale coefficient into lds
lds[tx] = __expf(lse_local - lse_local_logsum) * attn_sink_o_scale; lds[tx] = has_valid_lse ? __expf(lse_local - lse_local_logsum) * attn_sink_o_scale : 0.f;
// num_splits may not be 64, and thus need boundary judgement // num_splits may not be 64, and thus need boundary judgement
#pragma unroll #pragma unroll
...@@ -586,6 +588,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel( ...@@ -586,6 +588,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// read ultimate scale value for current split // read ultimate scale value for current split
bool within_splits = ((i + wave_id) < true_num_splits); bool within_splits = ((i + wave_id) < true_num_splits);
float s_scale = num_splits >= WARP_NUM ? lds[i + wave_id]: (within_splits ? lds[i + wave_id]: 0.f); float s_scale = num_splits >= WARP_NUM ? lds[i + wave_id]: (within_splits ? lds[i + wave_id]: 0.f);
within_splits = within_splits && (s_scale > 0.f);
#pragma unroll #pragma unroll
for (int t = 0; t < tx_float_count; t += 2) { for (int t = 0; t < tx_float_count; t += 2) {
// half -> float32, reduce precision loss // half -> float32, reduce precision loss
......
#include "fwd.h" #include "fwd.h"
#include <cstdlib>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <hip/hip_runtime.h>
#include "dsa_mls/fwd.h" #include "dsa_mls/fwd.h"
#include "phase1.h" #include "phase1.h"
namespace gfx93 { namespace gfx93 {
namespace {
bool is_current_device_gfx938() {
int device = 0;
hipDeviceProp_t prop{};
if (hipGetDevice(&device) != hipSuccess || hipGetDeviceProperties(&prop, device) != hipSuccess) {
return false;
}
const std::string arch_name = prop.gcnArchName;
return arch_name.substr(0, arch_name.find(':')) == "gfx938";
}
} // namespace
void run_fwd_kernel(const SparseAttnFwdParams& params) { void run_fwd_kernel(const SparseAttnFwdParams& params) {
if (gfx93::fwd::dsa_mls::should_run(params)) { const bool disable_dsa_mls_prefill = std::getenv("FLASH_MLA_DISABLE_DSA_MLS_PREFILL") != nullptr;
const bool enable_dsa_mls_prefill = is_current_device_gfx938();
if (enable_dsa_mls_prefill && !disable_dsa_mls_prefill && gfx93::fwd::dsa_mls::should_run(params)) {
gfx93::fwd::dsa_mls::run(params); gfx93::fwd::dsa_mls::run(params);
return; return;
} }
......
...@@ -57,7 +57,7 @@ def flash_mla_with_kvcache( ...@@ -57,7 +57,7 @@ def flash_mla_with_kvcache(
cache_seqlens: Optional[torch.Tensor], cache_seqlens: Optional[torch.Tensor],
head_dim_v: int, head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta, tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None, num_splits: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
is_fp8_kvcache: bool = False, is_fp8_kvcache: bool = False,
...@@ -78,7 +78,7 @@ def flash_mla_with_kvcache( ...@@ -78,7 +78,7 @@ def flash_mla_with_kvcache(
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512 head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface). num_splits: optional override for BF16 sparse decode. Other paths keep using sched_meta.
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention causal: bool. Whether to apply causal attention mask. Only valid for dense attention
is_fp8_kvcache: bool. is_fp8_kvcache: bool.
...@@ -104,7 +104,6 @@ def flash_mla_with_kvcache( ...@@ -104,7 +104,6 @@ def flash_mla_with_kvcache(
sched_meta = tile_scheduler_metadata sched_meta = tile_scheduler_metadata
indices_in_kvcache = indices indices_in_kvcache = indices
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
...@@ -155,14 +154,18 @@ def flash_mla_with_kvcache( ...@@ -155,14 +154,18 @@ def flash_mla_with_kvcache(
assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False" assert k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
if extra_k_cache is not None: if extra_k_cache is not None:
assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False" assert extra_k_cache.dtype == torch.bfloat16, "BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
else:
assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
decode_num_splits = num_splits if num_splits is not None else sched_meta.num_splits
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
q, k_cache, indices_in_kvcache, topk_length, attn_sink, q, k_cache, indices_in_kvcache, topk_length, attn_sink,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits, sched_meta.tile_scheduler_metadata, decode_num_splits,
extra_k_cache, extra_indices_in_kvcache, extra_topk_length, extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
head_dim_v, softmax_scale head_dim_v, softmax_scale
) )
else: else:
# Dense attention # Dense attention
assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
......
version = '1.2.0'
git_hash = 'ac8223a'
git_branch = 'master-aicc'
abi = 'abi1'
dtk = '2604'
torch_version = '2.5'
hcu_version = '1.2.0+das.opt1.dtk2604'
...@@ -54,11 +54,9 @@ def is_bf16_decode_supported_param(t: TestParam) -> bool: ...@@ -54,11 +54,9 @@ def is_bf16_decode_supported_param(t: TestParam) -> bool:
return False return False
if t.h_q not in [64, 128]: if t.h_q not in [64, 128]:
return False return False
if t.d_qk not in [512, 576]: if t.d_qk != 512:
return False return False
if t.decode.extra_topk is None: return True
return t.topk <= 1024
return t.topk <= 256 and t.decode.extra_topk <= 1024
@dataclasses.dataclass @dataclasses.dataclass
class RawTestParamForDecode: class RawTestParamForDecode:
......
...@@ -110,6 +110,11 @@ def gen_testcase() -> List[RawTestParam]: ...@@ -110,6 +110,11 @@ def gen_testcase() -> List[RawTestParam]:
(RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
# MODEL1 CONFIG4 # MODEL1 CONFIG4
(RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]),
# DSA BF16 large topk / extra_topk coverage
(RawTestParam(0, 64, 2, 1, 32768, True, topk=4096, d_qk=512, block_size=256, check_correctness=True, num_runs=0), [1, 2]),
(RawTestParam(0, 128, 2, 1, 32768, True, topk=8192, d_qk=576, block_size=256, check_correctness=True, num_runs=0), [1, 2]),
(RawTestParam(0, 64, 2, 1, 32768, True, topk=128, d_qk=512, extra_s_k=32768, extra_topk=4096, block_size=256, extra_block_size=256, check_correctness=True, num_runs=0), [1, 2]),
(RawTestParam(0, 128, 2, 1, 32768, True, topk=128, d_qk=512, extra_s_k=32768, extra_topk=8192, block_size=256, extra_block_size=256, have_extra_topk_length=True, check_correctness=True, num_runs=0), [1, 2]),
] ]
performance_cases = [ performance_cases = [
# Production cases # Production cases
...@@ -173,10 +178,13 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -173,10 +178,13 @@ def test_flash_mla(p: TestParam) -> Result:
result = kk.bench_kineto(run_decode, p.num_runs) result = kk.bench_kineto(run_decode, p.num_runs)
if lib.is_decode_bf16_kvcache(): if lib.is_decode_bf16_kvcache():
splitkv_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv" main_kernel_names = [
"flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv",
"flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64<",
]
combine_kernel_name = "flash_mla_splitkv_reduce_kernel" combine_kernel_name = "flash_mla_splitkv_reduce_kernel"
else: else:
splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel" main_kernel_names = ["flash_fwd_splitkv_mla_fp8_sparse_kernel"]
combine_kernel_name = "flash_fwd_mla_combine_kernel" combine_kernel_name = "flash_fwd_mla_combine_kernel"
# Get individual kernel time usages # Get individual kernel time usages
...@@ -188,21 +196,24 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -188,21 +196,24 @@ def test_flash_mla(p: TestParam) -> Result:
kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6 kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6
else: else:
kernel_time_usages_us[kernel_name] = None kernel_time_usages_us[kernel_name] = None
pick_kernel_time_usage(splitkv_kernel_name) for main_kernel_name in main_kernel_names:
pick_kernel_time_usage(main_kernel_name)
pick_kernel_time_usage(combine_kernel_name) pick_kernel_time_usage(combine_kernel_name)
# Get E2E time usages # Get E2E time usages
def have_kernel(name: str): def have_kernel(name: str):
return kernel_time_usages_us[name] is not None return kernel_time_usages_us[name] is not None
active_main_kernel_name = next((name for name in main_kernel_names if have_kernel(name)), None)
if kk.is_using_profiling_tools(): if kk.is_using_profiling_tools():
e2e_time_usage_us = 1e6 e2e_time_usage_us = 1e6
else: else:
assert have_kernel(splitkv_kernel_name) assert active_main_kernel_name is not None
if have_kernel(combine_kernel_name): if have_kernel(combine_kernel_name):
e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6 e2e_time_usage_us = result.get_e2e_time(active_main_kernel_name, combine_kernel_name) * 1e6
else: else:
e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name] e2e_time_usage_us = kernel_time_usages_us[active_main_kernel_name]
assert e2e_time_usage_us is not None assert e2e_time_usage_us is not None
flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t) flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t)
...@@ -216,12 +227,14 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -216,12 +227,14 @@ def test_flash_mla(p: TestParam) -> Result:
print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us') print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us')
print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}') print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}')
print(f'Time (per): {e2e_time_usage_us:.1f} us') print(f'Time (per): {e2e_time_usage_us:.1f} us')
print_kernel_time_usage(splitkv_kernel_name, "Splitkv") for main_kernel_name in main_kernel_names:
print_kernel_time_usage(main_kernel_name, "Decode")
print_kernel_time_usage(combine_kernel_name, "Combine") print_kernel_time_usage(combine_kernel_name, "Combine")
print(f'TFlops: {achieved_tflops:.1f}') print(f'TFlops: {achieved_tflops:.1f}')
print(f'GB/s: {achieved_gBps:.0f}') print(f'GB/s: {achieved_gBps:.0f}')
performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps) main_kernel_time_usage_us = kernel_time_usages_us[active_main_kernel_name] if active_main_kernel_name is not None else 0.0
performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, main_kernel_time_usage_us or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps)
is_correct = True is_correct = True
if p.check_correctness: if p.check_correctness:
...@@ -229,10 +242,11 @@ def test_flash_mla(p: TestParam) -> Result: ...@@ -229,10 +242,11 @@ def test_flash_mla(p: TestParam) -> Result:
with torch.profiler.record_function("reference_flash_mla"): with torch.profiler.record_function("reference_flash_mla"):
out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t) out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t)
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
if lib.is_decode_bf16_kvcache(): if lib.is_decode_bf16_kvcache():
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=6e-6)
is_correct &= is_out_correct is_correct &= is_out_correct
else: else:
is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6)
is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
is_correct &= is_out_correct and is_lse_correct is_correct &= is_out_correct and is_lse_correct
...@@ -261,10 +275,10 @@ def main(): ...@@ -261,10 +275,10 @@ def main():
bf16_testcases = [] bf16_testcases = []
seen_bf16_cases = set() seen_bf16_cases = set()
for t in testcases: for t in testcases:
if t.d_qk == 576:
t = dataclasses.replace(t, d_qk=512)
if not lib.is_bf16_decode_supported_param(t): if not lib.is_bf16_decode_supported_param(t):
continue continue
if t.num_runs > 0 and t.decode.b > 16:
t = dataclasses.replace(t, decode=dataclasses.replace(t.decode, b=16))
key = dataclasses.asdict(t) key = dataclasses.asdict(t)
key["decode"] = tuple(key["decode"].items()) if key["decode"] is not None else None key["decode"] = tuple(key["decode"].items()) if key["decode"] is not None else None
key = tuple(key.items()) key = tuple(key.items())
......
...@@ -11,6 +11,8 @@ import ref ...@@ -11,6 +11,8 @@ import ref
_counter = kk.Counter() _counter = kk.Counter()
def is_dsa_mls_prefill_case(p: TestParam) -> bool: def is_dsa_mls_prefill_case(p: TestParam) -> bool:
if get_gcn_arch_name() != "gfx938":
return False
if p.d_v != 512: if p.d_v != 512:
return False return False
if p.d_qk not in [512, 576]: if p.d_qk not in [512, 576]:
...@@ -26,7 +28,7 @@ def is_dsa_mls_prefill_case(p: TestParam) -> bool: ...@@ -26,7 +28,7 @@ def is_dsa_mls_prefill_case(p: TestParam) -> bool:
if p.d_qk == 512 and ((p.h_q == 64 and p.topk == 512) or (p.h_q == 128 and p.topk == 1024)): if p.d_qk == 512 and ((p.h_q == 64 and p.topk == 512) or (p.h_q == 128 and p.topk == 1024)):
return True return True
if p.d_qk == 576 and p.h_q == 64 and p.topk == 2048 and p.s_kv >= 32768: if p.d_qk == 576 and p.topk == 2048 and ((p.h_q == 64 and p.s_kv >= 24576) or (p.h_q == 128 and p.s_kv >= 8192)):
return True return True
return False return False
...@@ -53,9 +55,16 @@ def run_test(p: TestParam) -> bool: ...@@ -53,9 +55,16 @@ def run_test(p: TestParam) -> bool:
flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t) flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t)
bench_result = kk.bench_kineto(run_prefill, num_tests=p.num_runs) bench_result = kk.bench_kineto(run_prefill, num_tests=p.num_runs)
kernel_names = bench_result.get_kernel_names() kernel_names = bench_result.get_kernel_names()
prefill_kernel_name = "sparse_attn_fwd" prefill_kernel_name_candidates = [
if not any(prefill_kernel_name in name for name in kernel_names): "sparse_attn_fwd",
prefill_kernel_name = "flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64" "flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_topk2048_fast_nopage_64",
"flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64",
]
prefill_kernel_name = next(
(candidate for candidate in prefill_kernel_name_candidates
if any(candidate in name for name in kernel_names)),
prefill_kernel_name_candidates[0],
)
prefill_ans_time = bench_result.get_kernel_time(prefill_kernel_name) prefill_ans_time = bench_result.get_kernel_time(prefill_kernel_name)
prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12 prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12
prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12 prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12
...@@ -69,6 +78,8 @@ def run_test(p: TestParam) -> bool: ...@@ -69,6 +78,8 @@ def run_test(p: TestParam) -> bool:
is_correct = True is_correct = True
is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6) is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6)
# DSA MLS prefill is selected for throughput and currently only treats out as the validated contract.
# max_logits/lse can differ on boundary cases, so keep those checks on the Sugon path only.
if not is_dsa_mls_prefill_case(p): if not is_dsa_mls_prefill_case(p):
is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
...@@ -177,13 +188,13 @@ if __name__ == '__main__': ...@@ -177,13 +188,13 @@ if __name__ == '__main__':
performance_case_templates = [ performance_case_templates = [
# V3.2 # V3.2
(576, 128, 2048, [8192, 32768, 65536, 98304, 131072]), (576, 128, 2048, [8192, 16384, 65536, 98304, 131072]),
(576, 64, 2048, [8192, 32768, 65536, 98304, 131072]), (576, 64, 2048, [8192, 16384, 65536, 98304, 131072]),
# MODEL1 CONFIG1 # MODEL1 CONFIG1
(512, 64, 512, [8192, 32768, 49152, 65536]), # (512, 64, 512, [8192, 32768, 49152, 65536]),
# MODEL1 CONFIG2 # MODEL1 CONFIG2
(512, 128, 1024, [8192, 32768, 49152, 65536]), # (512, 128, 1024, [8192, 32768, 49152, 65536]),
(512, 16, 1024, [8192, 32768, 49152, 65536]), # (512, 16, 1024, [8192, 32768, 49152, 65536]),
] ]
performance_cases = [ performance_cases = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment