// fwd.cu
#include "fwd.h"

#include <stdexcept>
#include <string>

#include "phase1.h"

namespace sm90 {

9
10
void run_fwd_kernel(const SparseAttnFwdParams& params) {
    const bool have_topk_length = params.topk_length != nullptr;
11

12
13
14
15
16
17
    // Dispatch based on d_qk dimension and presence of topk_length
    if (params.d_qk == 512) {
        if (have_topk_length) {
            sm90::fwd::run_fwd_phase1_kernel<512, true>(params);
        } else {
            sm90::fwd::run_fwd_phase1_kernel<512, false>(params);
18
        }
19
20
21
    } else if (params.d_qk == 576) {
        if (have_topk_length) {
            sm90::fwd::run_fwd_phase1_kernel<576, true>(params);
22
        } else {
23
            sm90::fwd::run_fwd_phase1_kernel<576, false>(params);
24
25
        }
    } else {
26
        throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
27
28
29
    }
}

30
}  // namespace sm90