#include "fwd.h"

#include <stdexcept>

#include "phase1.h"

namespace gfx93 {
8

9
10
void run_fwd_kernel(const SparseAttnFwdParams& params) {
    const bool have_topk_length = params.topk_length != nullptr;
11

12
13
14
    // Dispatch based on d_qk dimension and presence of topk_length
    if (params.d_qk == 512) {
        if (have_topk_length) {
zhanghj2's avatar
zhanghj2 committed
15
            gfx93::fwd::run_fwd_phase1_kernel<512, true>(params);
16
        } else {
zhanghj2's avatar
zhanghj2 committed
17
            gfx93::fwd::run_fwd_phase1_kernel<512, false>(params);
18
        }
19
20
    } else if (params.d_qk == 576) {
        if (have_topk_length) {
zhanghj2's avatar
zhanghj2 committed
21
            gfx93::fwd::run_fwd_phase1_kernel<576, true>(params);
22
        } else {
zhanghj2's avatar
zhanghj2 committed
23
            gfx93::fwd::run_fwd_phase1_kernel<576, false>(params);
24
25
        }
    } else {
26
        throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
27
28
29
    }
}

zhanghj2's avatar
zhanghj2 committed
30
}  // namespace gfx93