// phase1.cuh
#pragma once

#include "config.h"

#include "utils.h"
#include "../../helpers.h"

namespace sm90::fwd {

using namespace cute;

/**
 * Device-side body of the phase-1 sparse-attention forward kernel.
 *
 * sparse_attn_fwd_kernel forwards its parameter struct to Kernel::devfunc for
 * whichever Kernel type it is instantiated with.
 *
 * @tparam D_QK              Compile-time integer parameter of KernelTemplate
 *                           (presumably the Q/K head dimension - TODO confirm).
 * @tparam HAVE_TOPK_LENGTH  Compile-time boolean parameter of KernelTemplate
 *                           (presumably whether top-k lengths are provided -
 *                           TODO confirm).
 * @param params             Forward-pass parameters; not read by the current body.
 */
template<int D_QK, bool HAVE_TOPK_LENGTH>
__device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params) {
    // NOTE(review): this body contains no statements, so the kernel currently
    // does no work. Confirm whether the implementation was removed on purpose.
}

/**
 * Generic __global__ entry point that delegates to a kernel policy type.
 *
 * Launch bounds: at most Kernel::NUM_THREADS threads per block, with a minimum
 * of one resident block per SM requested from the compiler.
 *
 * @tparam Kernel  Type exposing a static NUM_THREADS constant and a static
 *                 __device__ function devfunc(const SparseAttnFwdParams&).
 * @param params   Forward-pass parameters, passed by value as a kernel argument
 *                 and forwarded unchanged to Kernel::devfunc.
 */
template<typename Kernel>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1)
sparse_attn_fwd_kernel(const SparseAttnFwdParams params) {
    Kernel::devfunc(params);
}

/**
 * Host-side entry point: validates the launch parameters for this kernel
 * configuration.
 *
 * Preconditions enforced with KU_ASSERT:
 *   - params.h_kv == 1
 *   - params.topk is a multiple of 2 * B_TOPK
 *   - params.topk > 0
 *   - params.h_q is a multiple of B_H
 *
 * @tparam D_QK              Compile-time integer parameter of KernelTemplate.
 * @tparam HAVE_TOPK_LENGTH  Compile-time boolean parameter of KernelTemplate.
 * @param params             Forward-pass parameters to validate.
 */
template<int D_QK, bool HAVE_TOPK_LENGTH>
void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &params) {
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.topk % (2*B_TOPK) == 0);   // To save some boundary checks
    KU_ASSERT(params.topk > 0);
    KU_ASSERT(params.h_q % B_H == 0);

    // Shape tuple for Q, ordered as (h_q, d_qk, s_q).
    auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
    // NOTE(review): shape_Q is not used after this point and no kernel launch
    // is present in this function. Confirm whether the remaining setup and
    // launch code was removed on purpose.
}

/**
 * Free-function wrapper around KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run.
 *
 * @tparam D_QK              Forwarded to KernelTemplate as its first argument.
 * @tparam HAVE_TOPK_LENGTH  Forwarded to KernelTemplate as its second argument.
 * @param params             Forward-pass parameters, passed through unchanged.
 */
template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
    // Name the concrete instantiation once, then dispatch to its static runner.
    using Impl = KernelTemplate<D_QK, HAVE_TOPK_LENGTH>;
    Impl::run(params);
}

}