#include "fwd.h" #include #include #include #include "dispatch.h" namespace gfx93::fwd::dsa_mls { bool can_run(const SparseAttnFwdParams& params) { if (params.d_v != 512) return false; if (params.d_qk != 512 && params.d_qk != 576) return false; if (params.h_kv != 1) return false; if (params.h_q != 64 && params.h_q != 128) return false; if (!(params.topk <= 1024 || params.topk == 2048)) return false; if (params.topk == 2048 && (params.attn_sink != nullptr || params.topk_length != nullptr)) return false; return true; } bool should_run(const SparseAttnFwdParams& params) { if (!can_run(params)) return false; if (params.d_qk == 512 && ((params.h_q == 64 && params.topk == 512) || (params.h_q == 128 && params.topk == 1024))) { return true; } if (params.d_qk == 576 && params.h_q == 64 && params.topk == 2048 && params.s_kv >= 32768) { return true; } return false; } static Flash_fwd_mla_params_dsa make_legacy_params(const SparseAttnFwdParams& src) { if (!can_run(src)) { throw std::runtime_error( "DSA MLS sparse prefill only supports d_qk=512/576, d_v=512, " "h_kv=1, h_q=64/128, topk<=1024 or topk=2048 without attn_sink/topk_length"); } Flash_fwd_mla_params_dsa dst; std::memset(&dst, 0, sizeof(dst)); dst.layout = 1; dst.b = 1; dst.h = 1; dst.h_k = src.h_kv; dst.h_h_k_ratio = dst.h / dst.h_k; dst.mtp = 1; dst.ngroups = src.h_q / src.h_kv; dst.topk = src.topk; dst.d = src.d_qk; dst.d_v = src.d_v; dst.scale_softmax = src.sm_scale; dst.scale_softmax_log2 = src.sm_scale_div_log2; dst.cu_seqlens_q = nullptr; dst.cu_seqlens_k = nullptr; dst.cu_seqlens_k_new = nullptr; dst.topk_length = src.topk_length; dst.attn_sink = src.attn_sink; dst.q_ptr = src.q; dst.k_ptr = src.kv; dst.v_ptr = src.kv; dst.o_ptr = src.out; dst.sparse_indices = src.indices; dst.softmax_lse_ptr = src.lse; dst.scores_max_ptr = src.max_logits; dst.scores_sum_ptr = nullptr; dst.block_table = nullptr; dst.block_table_batch_stride = 0; dst.page_block_size = 0; dst.is_causal = false; dst.q_batch_stride = 0; dst.q_token_stride = src.stride_q_s_q; dst.q_head_stride = src.stride_q_h_q; dst.q_row_stride = dst.q_head_stride; dst.k_batch_stride = 0; dst.k_row_stride = src.stride_kv_s_kv; dst.k_head_stride = src.stride_kv_h_kv; dst.v_batch_stride = 0; dst.v_row_stride = src.stride_kv_s_kv; dst.v_head_stride = src.stride_kv_h_kv; dst.o_batch_stride = 0; dst.o_row_stride = src.h_q * src.d_v; dst.o_head_stride = src.d_v; dst.sparse_indices_batch_stride = 0; dst.sparse_indices_row_stride = src.stride_indices_s_q; dst.sparse_indices_head_stride = src.stride_indices_h_kv; dst.sparse_indices_topk_stride = 1; dst.seqlen_q = src.s_q * dst.ngroups; dst.seqlen_k = src.s_kv; dst.max_seqlen = src.s_q; dst.is_bf16 = true; dst.is_e4m3 = false; dst.is_int8 = false; dst.cu_count = src.num_sm; dst.seqlenq_ngroups_swapped = true; dst.is_seqlens_k_cumulative = false; dst.splitkv_use_fp32_as_accum = false; dst.num_splits = 0; dst.partition_size = src.topk; return dst; } void run(const SparseAttnFwdParams& params) { Flash_fwd_mla_params_dsa legacy_params = make_legacy_params(params); hipStream_t stream = reinterpret_cast(params.stream); if (params.d_qk == 512) { run_dsa_prefill_nopage_64_dispatch(legacy_params, stream); } else if (params.d_qk == 576) { run_dsa_prefill_nopage_64_dispatch(legacy_params, stream); } else { throw std::runtime_error("Unsupported d_qk value in DSA MLS sparse prefill"); } } } // namespace gfx93::fwd::dsa_mls