// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/fast_math.h>

#include "kernels/config.h"
#include "kernels/get_mla_metadata.h"
#include "kernels/mla_combine.h"
#include "kernels/params.h"
#include "kernels/splitkv_mla.h"

// Argument-validation helpers built on TORCH_CHECK. Each macro stringifies its
// first argument (#x) so the failure message names the offending argument.
// The macro parameter is parenthesized so that expression arguments (for
// example `*ptr`) bind to the member call as intended.

// Fails unless x.is_cuda() is true.
#define CHECK_DEVICE(x) TORCH_CHECK((x).is_cuda(), #x " must be on CUDA")
// Fails unless x.sizes() equals exactly the listed dimensions.
#define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// Fails unless x.is_contiguous() is true.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

/**
 * Compute the tile-scheduler metadata and per-batch split table for the MLA
 * split-KV kernels by launching run_get_mla_metadata_kernel.
 *
 * @param seqlens_k             int32, contiguous CUDA tensor; its size(0) is
 *                              used as the batch size.
 * @param num_heads_per_head_k  Number of query heads per KV head; must be > 0.
 * @param num_heads_k           Number of KV heads; must be > 0.
 * @return {tile_scheduler_metadata, num_splits}: int32 tensors allocated with
 *         seqlens_k's options, of shapes
 *         (num_sm_parts, TileSchedulerMetaDataSize) and (batch_size + 1).
 *         Their contents are written by the kernel enqueued on the current
 *         CUDA stream.
 * @throws c10::Error (via TORCH_CHECK) when an argument fails validation.
 */
std::vector<at::Tensor>
get_mla_metadata(
    at::Tensor &seqlens_k,
    const int num_heads_per_head_k,
    const int num_heads_k
) {
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    // Both values end up as divisors below; reject them here instead of
    // hitting an integer division by zero on the host.
    TORCH_CHECK(num_heads_k > 0, "num_heads_k must be positive");
    TORCH_CHECK(num_heads_per_head_k > 0, "num_heads_per_head_k must be positive");

    // Make seqlens_k's device current *before* querying device properties, so
    // the SM count below belongs to the GPU that will actually run the kernel.
    at::cuda::CUDAGuard device_guard{seqlens_k.device()};

    const int batch_size = static_cast<int>(seqlens_k.size(0));
    int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
    const auto options = seqlens_k.options();

    const auto dprops = at::cuda::getCurrentDeviceProperties();
    const int sm_count = dprops->multiProcessorCount;
    // NOTE(review): this integer division yields 0 when
    // num_heads_k * ceil_div(num_heads_per_head_k, BLOCK_SIZE_M) > sm_count;
    // confirm whether the kernels tolerate num_sm_parts == 0.
    const int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M);

    auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
    auto num_splits = torch::empty({batch_size + 1}, options);
    int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    int *num_splits_ptr = num_splits.data_ptr<int>();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    Mla_metadata_params params = {};
    params.seqlens_k_ptr = seqlens_k_ptr;
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
    params.num_splits_ptr = num_splits_ptr;
    params.batch_size = batch_size;
    params.block_size_n = Config::PAGE_BLOCK_SIZE;
    params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS;
    params.num_sm_parts = num_sm_parts;
    run_get_mla_metadata_kernel(params, stream);

    return {tile_scheduler_metadata, num_splits};
}

/**
 * Forward pass of MLA attention over a paged KV cache: validates the inputs,
 * fills a Flash_fwd_mla_params struct, and launches the split-KV kernel
 * followed by the combine kernel on the current CUDA stream.
 *
 * @param q                        batch_size x seqlen_q x num_heads x head_size
 *                                 (bf16 or fp16, last dimension contiguous).
 *                                 NOTE: the reference is rebound to a reshaped
 *                                 tensor inside this function.
 * @param kcache                   num_blocks x page_block_size x num_heads_k x head_size,
 *                                 same dtype as q.
 * @param head_size_v              Value head size; only 512 is accepted.
 * @param seqlens_k                int32, shape (batch_size).
 * @param block_table              int32, batch_size x max_num_blocks_per_seq.
 * @param softmax_scale            Scale stored in params.scale_softmax.
 * @param is_causal                Causal flag; forced to false when seqlen_q == 1.
 * @param tile_scheduler_metadata  int32, num_sm_parts x TileSchedulerMetaDataSize.
 * @param num_splits               int32, shape (batch_size + 1).
 * @return {out, softmax_lse}: out has shape
 *         (batch_size, seqlen_q, num_heads_q, head_size_v) in q's dtype;
 *         softmax_lse has shape (batch_size, num_heads_q, seqlen_q) in float32.
 * @throws c10::Error (via TORCH_CHECK) when an argument fails validation or the
 *         current device is not compute capability 9.0.
 */
std::vector<at::Tensor>
mha_fwd_kvcache_mla(
    at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size
    const int head_size_v,
    const at::Tensor &seqlens_k,                 // batch_size
    const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits                 // batch_size + 1
) {
    // Check device
    CHECK_DEVICE(q);
    CHECK_DEVICE(kcache);
    CHECK_DEVICE(seqlens_k);
    CHECK_DEVICE(block_table);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_DEVICE(num_splits);

    // Make q's device current before anything that depends on the "current"
    // device (architecture query, allocations, stream lookup).
    at::cuda::CUDAGuard device_guard{q.device()};

    // Check the architecture
    const auto dprops = at::cuda::getCurrentDeviceProperties();
    const bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90, "mha_fwd_kvcache_mla requires a device with compute capability 9.0");

    // Check data types
    const auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf, "query must have dtype bfloat16 or float16");
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");

    // Check ranks before any dimension is indexed; q.sizes()[i] in particular
    // is not guaranteed to be bounds-checked.
    TORCH_CHECK(q.dim() == 4, "q must have 4 dimensions (batch_size, seqlen_q, num_heads, head_size)");
    TORCH_CHECK(kcache.dim() == 4, "kcache must have 4 dimensions (num_blocks, page_block_size, num_heads_k, head_size)");
    TORCH_CHECK(block_table.dim() == 2, "block_table must have 2 dimensions (batch_size, max_num_blocks_per_seq)");
    TORCH_CHECK(tile_scheduler_metadata.dim() == 2, "tile_scheduler_metadata must have 2 dimensions (num_sm_parts, TileSchedulerMetaDataSize)");

    // Check layout
    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(seqlens_k);
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(num_splits);

    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_q = sizes[2];
    const int head_size_k = sizes[3];
    TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    // num_heads_k is used as a divisor right below.
    TORCH_CHECK(num_heads_k > 0, "Number of heads in key/value must be positive");
    TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (seqlen_q_ori == 1) { is_causal = false; }

    // Fold the query heads that share one KV head into the sequence dimension:
    // (b, s_q, h_k, h_q/h_k, d) -> (b, s_q * h_q/h_k, h_k, d).
    const int num_q_heads_per_hk = num_heads_q / num_heads_k;
    const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
    const int num_heads = num_heads_k;
    q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
            .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});

    CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    CHECK_SHAPE(seqlens_k, batch_size);
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize, "tile_scheduler_metadata must have TileSchedulerMetaDataSize columns");
    CHECK_SHAPE(num_splits, batch_size+1);

    auto opts = q.options();
    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    CHECK_CONTIGUOUS(softmax_lse);

    Flash_fwd_mla_params params = {};
    // Set the sizes.
    params.b = batch_size;
    params.s_q = seqlen_q_ori;
    params.q_seq_per_hk = q_seq_per_hk;
    params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
    params.h_q = num_heads_q;
    params.h_k = num_heads_k;
    params.num_blocks = num_blocks;
    params.q_head_per_hk = num_q_heads_per_hk;
    params.is_causal = is_causal;
    params.d = head_size_k;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    // All stride are in elements, not bytes.
    params.q_batch_stride = q.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = kcache.stride(-3);
    params.o_row_stride = out.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = kcache.stride(-2);
    params.o_head_stride = out.stride(-2);

    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;

    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);
    params.num_splits_ptr = num_splits.data_ptr<int>();

    // Float32 scratch buffers for the per-split partial results that the
    // combine kernel reads.
    const int total_num_splits = batch_size + params.num_sm_parts;
    at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
    CHECK_CONTIGUOUS(softmax_lse_accum);
    CHECK_CONTIGUOUS(out_accum);
    params.total_num_splits = total_num_splits;
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    if (q_dtype == torch::kBFloat16) {
        run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
        run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
    } else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
        TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
        run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
        run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
    } else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head folding so the results are indexed by the original query
    // heads and sequence positions.
    out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
            .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
            .reshape({batch_size, num_heads_q, seqlen_q_ori});

    return {out, softmax_lse};
}

// Python extension entry point: exposes the two host-side functions defined
// in this file under the module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Resolve the C++ entry points once, then register them by Python name.
    auto metadata_fn = &get_mla_metadata;
    auto forward_fn = &mha_fwd_kvcache_mla;

    m.doc() = "FlashMLA";
    m.def("get_mla_metadata", metadata_fn);
    m.def("fwd_kvcache_mla", forward_fn);
}