#include "fwd.h"

// NOTE(review): the original text carried five system #include directives
// whose header names were lost. The four below are the ones the visible code
// needs (HIP runtime calls, std::getenv, std::runtime_error, std::string);
// confirm whether a fifth header is still required.
#include <hip/hip_runtime.h>

#include <cstdlib>
#include <stdexcept>
#include <string>

#include "dsa_mls/fwd.h"
#include "phase1.h"

namespace gfx93 {

namespace {

/// Reports whether the currently selected HIP device is a "gfx938" part.
///
/// The device's `gcnArchName` is compared against "gfx938" after dropping
/// everything from the first ':' onward (if there is no ':', the whole name
/// is compared).
///
/// @return true when the truncated architecture name equals "gfx938";
///         false otherwise, including when either HIP query fails.
bool is_current_device_gfx938() {
    int device = 0;
    hipDeviceProp_t prop{};
    if (hipGetDevice(&device) != hipSuccess ||
        hipGetDeviceProperties(&prop, device) != hipSuccess) {
        // Treat a failed query as "not gfx938" so the caller falls back to
        // the generic path instead of erroring out here.
        return false;
    }

    const std::string arch_name = prop.gcnArchName;
    // find() yields npos when ':' is absent, and substr(0, npos) is the
    // whole string, so both name forms are handled by the same expression.
    return arch_name.substr(0, arch_name.find(':')) == "gfx938";
}

}  // namespace

/// Runs the sparse attention forward pass for `params`, choosing the kernel.
///
/// Selection order:
///   1. The dsa_mls path, when all of the following hold:
///        - the environment variable FLASH_MLA_DISABLE_DSA_MLS_PREFILL is
///          not set (being set to any value, even empty, disables the path),
///        - the current device is gfx938,
///        - `gfx93::fwd::dsa_mls::should_run(params)` returns true.
///   2. Otherwise the phase-1 kernel, instantiated on `params.d_qk`
///      (512 or 576) and on whether `params.topk_length` is non-null.
///
/// @param params  Forward-pass parameters; `d_qk` and `topk_length` are the
///                fields read directly here.
/// @throws std::runtime_error if the phase-1 path is taken and `params.d_qk`
///         is neither 512 nor 576.
void run_fwd_kernel(const SparseAttnFwdParams& params) {
    const bool dsa_mls_disabled_by_env =
        std::getenv("FLASH_MLA_DISABLE_DSA_MLS_PREFILL") != nullptr;

    // The env override is tested first so that the HIP device query is
    // short-circuited away when the dsa_mls path has been disabled anyway.
    if (!dsa_mls_disabled_by_env && is_current_device_gfx938() &&
        gfx93::fwd::dsa_mls::should_run(params)) {
        gfx93::fwd::dsa_mls::run(params);
        return;
    }

    const bool have_topk_length = params.topk_length != nullptr;

    // Dispatch based on d_qk dimension and presence of topk_length
    if (params.d_qk == 512) {
        if (have_topk_length) {
            gfx93::fwd::run_fwd_phase1_kernel<512, true>(params);
        } else {
            gfx93::fwd::run_fwd_phase1_kernel<512, false>(params);
        }
    } else if (params.d_qk == 576) {
        if (have_topk_length) {
            gfx93::fwd::run_fwd_phase1_kernel<576, true>(params);
        } else {
            gfx93::fwd::run_fwd_phase1_kernel<576, false>(params);
        }
    } else {
        throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
    }
}

}  // namespace gfx93