fwd.cu 1.61 KB
Newer Older
1
2
#include "fwd.h"

3
#include <cstdlib>
4
#include <stdexcept>
5
6
7
#include <string>

#include <hip/hip_runtime.h>
8

shenzhe's avatar
shenzhe committed
9
#include "dsa_mls/fwd.h"
10
#include "phase1.h"
11

zhanghj2's avatar
zhanghj2 committed
12
namespace gfx93 {
13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
namespace {

bool is_current_device_gfx938() {
    int device = 0;
    hipDeviceProp_t prop{};
    if (hipGetDevice(&device) != hipSuccess || hipGetDeviceProperties(&prop, device) != hipSuccess) {
        return false;
    }
    const std::string arch_name = prop.gcnArchName;
    return arch_name.substr(0, arch_name.find(':')) == "gfx938";
}

}  // namespace

28
void run_fwd_kernel(const SparseAttnFwdParams& params) {
29
30
31
    const bool disable_dsa_mls_prefill = std::getenv("FLASH_MLA_DISABLE_DSA_MLS_PREFILL") != nullptr;
    const bool enable_dsa_mls_prefill = is_current_device_gfx938();
    if (enable_dsa_mls_prefill && !disable_dsa_mls_prefill && gfx93::fwd::dsa_mls::should_run(params)) {
shenzhe's avatar
shenzhe committed
32
33
34
35
        gfx93::fwd::dsa_mls::run(params);
        return;
    }

36
    const bool have_topk_length = params.topk_length != nullptr;
37

38
39
40
    // Dispatch based on d_qk dimension and presence of topk_length
    if (params.d_qk == 512) {
        if (have_topk_length) {
zhanghj2's avatar
zhanghj2 committed
41
            gfx93::fwd::run_fwd_phase1_kernel<512, true>(params);
42
        } else {
zhanghj2's avatar
zhanghj2 committed
43
            gfx93::fwd::run_fwd_phase1_kernel<512, false>(params);
44
        }
45
46
    } else if (params.d_qk == 576) {
        if (have_topk_length) {
zhanghj2's avatar
zhanghj2 committed
47
            gfx93::fwd::run_fwd_phase1_kernel<576, true>(params);
48
        } else {
zhanghj2's avatar
zhanghj2 committed
49
            gfx93::fwd::run_fwd_phase1_kernel<576, false>(params);
50
51
        }
    } else {
52
        throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
53
54
55
    }
}

zhanghj2's avatar
zhanghj2 committed
56
}  // namespace gfx93