fwd.cu 998 Bytes
Newer Older
1
2
#include "fwd.h"

3
#include <stdexcept>
4

shenzhe's avatar
shenzhe committed
5
#include "dsa_mls/fwd.h"
6
#include "phase1.h"
7

zhanghj2's avatar
zhanghj2 committed
8
namespace gfx93 {
9

10
void run_fwd_kernel(const SparseAttnFwdParams& params) {
shenzhe's avatar
shenzhe committed
11
12
13
14
15
    if (gfx93::fwd::dsa_mls::should_run(params)) {
        gfx93::fwd::dsa_mls::run(params);
        return;
    }

16
    const bool have_topk_length = params.topk_length != nullptr;
17

18
19
20
    // Dispatch based on d_qk dimension and presence of topk_length
    if (params.d_qk == 512) {
        if (have_topk_length) {
zhanghj2's avatar
zhanghj2 committed
21
            gfx93::fwd::run_fwd_phase1_kernel<512, true>(params);
22
        } else {
zhanghj2's avatar
zhanghj2 committed
23
            gfx93::fwd::run_fwd_phase1_kernel<512, false>(params);
24
        }
25
26
    } else if (params.d_qk == 576) {
        if (have_topk_length) {
zhanghj2's avatar
zhanghj2 committed
27
            gfx93::fwd::run_fwd_phase1_kernel<576, true>(params);
28
        } else {
zhanghj2's avatar
zhanghj2 committed
29
            gfx93::fwd::run_fwd_phase1_kernel<576, false>(params);
30
31
        }
    } else {
32
        throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
33
34
35
    }
}

zhanghj2's avatar
zhanghj2 committed
36
}  // namespace gfx93