// splitkv_mla.cuh
#include <cutlass/cutlass.h>

#include "utils.h"
4
5

#include "params.h"
6
7
8
9
10
#include "config.h"
#include "traits.h"

using namespace cute;

11
12
namespace sm90 {

13
14
15
16
17
18
19
// Two different large-negative sentinels are used on purpose:
//   - MAX_INIT_VAL_SM is the value sM is initialized with;
//   - MAX_INIT_VAL    is the value used for masking.
// The maximum is updated as
//     new_max = max(sM(row_idx), cur_max * scale_softmax_log2)
// so a masked value, once multiplied by scale_softmax_log2, must still be
// smaller than the initial sM value. In other words we must guarantee
//     MAX_INIT_VAL * scale_softmax_log2 < MAX_INIT_VAL_SM.
// With the constants below (-1e33 vs. -1e30) this holds whenever
// scale_softmax_log2 > 1e-3.
static constexpr float MAX_INIT_VAL_SM = -1e30f;
static constexpr float MAX_INIT_VAL = -1e33f;


zhanghj2's avatar
zhanghj2 committed
20
21
22
// Device entry point that run_flash_splitkv_mla_kernel() takes the address of.
//
// Template parameters:
//   T - kernel traits type. T::NUM_THREADS is used as the maximum number of
//       threads per block in __launch_bounds__; the second launch-bounds
//       argument requests at least 1 resident block per SM.
//
// Parameters:
//   params - decode-attention parameters, passed to the kernel by value.
//
// NOTE(review): the kernel body is empty in this file, so launching it performs
// no work. Presumably the implementation lives elsewhere or is still to be
// added -- confirm before relying on this entry point.
template<typename T>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
    // Intentionally empty (see the note above).
}


// Host-side launcher scaffold for flash_fwd_splitkv_mla_kernel<Traits<InputT>>.
//
// Template parameters:
//   InputT - element type forwarded to Traits<InputT> to select the kernel traits.
//
// Parameters:
//   params - decode-attention parameters. params.d and params.d_v are checked
//            against Config::HEAD_DIM_K and Config::HEAD_DIM_V respectively.
//
// What the function currently does:
//   1. Asserts that the head dimensions in `params` match the compiled Config.
//   2. Builds the Q shape, takes the kernel's address, and computes the
//      shared-memory size and the number of M blocks.
//   3. Calls CHECK_CUDA_KERNEL_LAUNCH().
//
// NOTE(review): the actual cudaLaunchKernelEx call is commented out, so no
// kernel is launched before CHECK_CUDA_KERNEL_LAUNCH() runs, and shape_Q,
// mla_kernel, smem_size and num_m_block are computed but not consumed. The
// disabled snippet also refers to `mla_kernel_attributes` and `tma_params`,
// neither of which is declared in this function, and the kernel declared in
// this file takes a single `params` argument -- the snippet needs updating
// before it can be re-enabled.
template<typename InputT>
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
    // The kernel is specialized for fixed head dimensions; reject anything else.
    FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
    FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);

    using T = Traits<InputT>;
    auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);

    auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T>;
    // Shared-memory footprint is taken from the traits' SharedMemoryPlan type.
    constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan);

    // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
    // Number of blocks needed to cover q_seq_per_hk in tiles of BLOCK_SIZE_M.
    const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);

    // Launch currently disabled (kept for reference):
    // cudaLaunchConfig_t mla_kernel_config = {
    //     dim3(num_m_block, params.h_k, params.num_sm_parts),
    //     dim3(T::NUM_THREADS, 1, 1),
    //     smem_size,
    //     params.stream,
    //     mla_kernel_attributes,
    //     1
    // };
    // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
    CHECK_CUDA_KERNEL_LAUNCH();
}

53
}