/******************************************************************************
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "fmha.h"

// CHECK_SHAPE(x, d0, d1, ...): TORCH_CHECK that tensor `x` has exactly the sizes (d0, d1, ...).
// On failure the message quotes the source text of `x` and of the expected sizes.
// `x` is parenthesized so that an expression argument (e.g. `*ptr` or `a ? b : c`)
// binds correctly to `.sizes()`.
#define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")


void set_params_fprop(FMHA_fprop_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t h,
                      const size_t d,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
49
                      at::Tensor out,
Tri Dao's avatar
Tri Dao committed
50
51
52
53
54
55
56
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *o_tmp_d,
                      void *s_d,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
Tri Dao's avatar
Tri Dao committed
57
58
                      bool is_causal,
                      int num_splits) {
Tri Dao's avatar
Tri Dao committed
59
60

    Data_type acc_type = DATA_TYPE_FP32;
Tri Dao's avatar
Tri Dao committed
61
    Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
Tri Dao's avatar
Tri Dao committed
62
63
64
65

    // Reset the parameters
    memset(&params, 0, sizeof(params));

Tri Dao's avatar
Tri Dao committed
66
67
    params.is_bf16 = q.dtype() == torch::kBFloat16;

Tri Dao's avatar
Tri Dao committed
68
    // Set the pointers and strides.
Tri Dao's avatar
Tri Dao committed
69
70
71
72
73
74
75
76
77
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    params.q_row_stride_in_elts = q.stride(0);
    params.k_row_stride_in_elts = k.stride(0);
    params.v_row_stride_in_elts = v.stride(0);
    params.q_head_stride_in_elts = q.stride(1);
    params.k_head_stride_in_elts = k.stride(1);
    params.v_head_stride_in_elts = v.stride(1);
78
79
80
    params.o_ptr = out.data_ptr();
    params.o_row_stride_in_elts = out.stride(0);
    params.o_head_stride_in_elts = out.stride(1);
Tri Dao's avatar
Tri Dao committed
81
    params.o_tmp_ptr = o_tmp_d;
82
83
    params.o_tmp_row_stride_in_elts = h * d;
    params.o_tmp_head_stride_in_elts = d;
Tri Dao's avatar
Tri Dao committed
84

Tri Dao's avatar
Tri Dao committed
85
86
    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
Tri Dao's avatar
Tri Dao committed
87
88
89

    // S = softmax(P)
    params.s_ptr = s_d;
Tri Dao's avatar
Tri Dao committed
90
    params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type);
Tri Dao's avatar
Tri Dao committed
91
92
93
94
95
96
97

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
Tri Dao's avatar
Tri Dao committed
98
99
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
Tri Dao's avatar
Tri Dao committed
100
101
102
103
104
105
106
107
108
109
110
111
    params.d = d;

    // Set the different scale values.
    // const float scale_bmm1 = 1.f / sqrtf(d);
    const float scale_bmm1 = softmax_scale;

    params.scale_bmm1f = scale_bmm1;
    set_alpha(params.scale_bmm1, scale_bmm1, data_type);

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
112
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
Tri Dao's avatar
Tri Dao committed
113
114
115
    params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.rp_dropout = 1.f / params.p_dropout;
116
    params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
Tri Dao's avatar
Tri Dao committed
117
118
119
120
    TORCH_CHECK(p_dropout < 1.f);
    set_alpha(params.scale_dropout, params.rp_dropout, data_type);

    params.is_causal = is_causal;
Tri Dao's avatar
Tri Dao committed
121
    params.num_splits = num_splits;
Tri Dao's avatar
Tri Dao committed
122
123
}

void set_params_dgrad(FMHA_dgrad_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t h,
                      const size_t d,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
135
                      const at::Tensor out,
Tri Dao's avatar
Tri Dao committed
136
137
138
139
140
141
142
143
144
145
146
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *dq_tmp_d,
                      void *do_packed_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
Tri Dao's avatar
Tri Dao committed
147
148
                      bool is_causal,
                      int num_splits) {
Tri Dao's avatar
Tri Dao committed
149
150
151

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, h, d,
152
                     q, k, v, out,
Tri Dao's avatar
Tri Dao committed
153
154
155
156
157
158
159
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     dq_tmp_d,  // Reusing the o_tmp_ptr variable to store dq_tmp
                     nullptr,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
Tri Dao's avatar
Tri Dao committed
160
161
                     is_causal,
                     num_splits);
Tri Dao's avatar
Tri Dao committed
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

    // Set the pointers and strides.
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride_in_elts = dq.stride(0);
    params.dk_row_stride_in_elts = dk.stride(0);
    params.dv_row_stride_in_elts = dv.stride(0);
    params.dq_head_stride_in_elts = dq.stride(1);
    params.dk_head_stride_in_elts = dk.stride(1);
    params.dv_head_stride_in_elts = dv.stride(1);
    params.do_ptr = do_packed_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;
}

std::vector<at::Tensor>
Tri Dao's avatar
Tri Dao committed
180
181
182
mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &v,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
183
        at::Tensor &out,             // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
Tri Dao's avatar
Tri Dao committed
184
185
186
187
        const at::Tensor &cu_seqlens_q,  // b+1
        const at::Tensor &cu_seqlens_k,  // b+1
        const int max_seqlen_q_,
        const int max_seqlen_k_,
Tri Dao's avatar
Tri Dao committed
188
189
190
191
192
        const float p_dropout,
        const float softmax_scale,
        const bool zero_tensors,
        const bool is_causal,
        const bool return_softmax,
Tri Dao's avatar
Tri Dao committed
193
        const int num_splits,
Tri Dao's avatar
Tri Dao committed
194
195
196
        c10::optional<at::Generator> gen_) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
Tri Dao's avatar
Tri Dao committed
197
    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
Tri Dao's avatar
Tri Dao committed
198
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
Tri Dao's avatar
Tri Dao committed
199
200
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    TORCH_CHECK(is_sm8x || is_sm75);
Tri Dao's avatar
Tri Dao committed
201
202
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    bool is_dropout = p_dropout > 0.0;
Tri Dao's avatar
Tri Dao committed
203
204
    Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

Tri Dao's avatar
Tri Dao committed
205
206
207
208
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
    TORCH_CHECK(k.dtype() == q_dtype);
    TORCH_CHECK(v.dtype() == q_dtype);
209
    TORCH_CHECK(out.dtype() == q_dtype);
Tri Dao's avatar
Tri Dao committed
210
211
212
213
214
215
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);

    TORCH_CHECK(q.is_cuda());
    TORCH_CHECK(k.is_cuda());
    TORCH_CHECK(v.is_cuda());
216
    TORCH_CHECK(out.is_cuda());
Tri Dao's avatar
Tri Dao committed
217
218
219
220
221
222
    TORCH_CHECK(cu_seqlens_q.is_cuda());
    TORCH_CHECK(cu_seqlens_k.is_cuda());

    TORCH_CHECK(q.stride(-1) == 1);
    TORCH_CHECK(k.stride(-1) == 1);
    TORCH_CHECK(v.stride(-1) == 1);
223
    TORCH_CHECK(out.stride(-1) == 1);
YangShu's avatar
YangShu committed
224
    TORCH_CHECK(cu_seqlens_q.is_contiguous());
Tri Dao's avatar
Tri Dao committed
225
226
227
228
229
230
    TORCH_CHECK(cu_seqlens_k.is_contiguous());

    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    const int total_q = sizes[TOTAL_DIM];
Tri Dao's avatar
Tri Dao committed
231
232
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
Tri Dao's avatar
Tri Dao committed
233
    const int total_k = k.size(TOTAL_DIM);
Tri Dao's avatar
Tri Dao committed
234
    TORCH_CHECK(batch_size > 0);
235
    TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
Tri Dao's avatar
Tri Dao committed
236

Tri Dao's avatar
Tri Dao committed
237
238
239
    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads, head_size);
    CHECK_SHAPE(v, total_k, num_heads, head_size);
240
    CHECK_SHAPE(out, total_q, num_heads, head_size);
Tri Dao's avatar
Tri Dao committed
241
242
243
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

244
    int blocksize_c = head_size > 64 ? 128 : 256;
Tri Dao's avatar
Tri Dao committed
245
246
247
248
249
250
    // Need to round max_seqlen_k to multiples of blocksize_c
    int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
    if( max_seqlen_k_ <= 128 ) {
        max_seqlen_k = 128;
    } else if( max_seqlen_k_ <= 256 ) {
        max_seqlen_k = 256;
Tri Dao's avatar
Tri Dao committed
251
    }
Tri Dao's avatar
Tri Dao committed
252
253
    int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
    bool loop = max_seqlen_k > blocksize_c;
Tri Dao's avatar
Tri Dao committed
254

255
    // Otherwise the kernel will be launched from cuda:0 device
256
257
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
258

Tri Dao's avatar
Tri Dao committed
259
    auto opts = q.options();
Tri Dao's avatar
Tri Dao committed
260

261
    // auto o = torch::empty({ total_q, num_heads, head_size }, opts);
Tri Dao's avatar
Tri Dao committed
262
263

    at::Tensor o_tmp;
Tri Dao's avatar
Tri Dao committed
264
    if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
Tri Dao's avatar
Tri Dao committed
265

Tri Dao's avatar
Tri Dao committed
266
267
    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
    // auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
Tri Dao's avatar
Tri Dao committed
268
269

    at::Tensor s;
Tri Dao's avatar
Tri Dao committed
270
    if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
Tri Dao's avatar
Tri Dao committed
271
272

    if( zero_tensors ) {
273
        out.zero_();
Tri Dao's avatar
Tri Dao committed
274
275
276
277
278
279
280
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) {s.zero_();}
    }

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

Tri Dao's avatar
Tri Dao committed
281
282
283
284
285
286
    set_params_fprop(launch_params.params,
                     batch_size,
                     max_seqlen_q,
                     max_seqlen_k,
                     num_heads,
                     head_size,
287
                     q, k, v, out,
Tri Dao's avatar
Tri Dao committed
288
289
290
291
292
293
294
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? o_tmp.data_ptr() : nullptr,
                     return_softmax ? s.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
Tri Dao's avatar
Tri Dao committed
295
296
                     is_causal,
                     num_splits);
Tri Dao's avatar
Tri Dao committed
297
298
299

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
Tri Dao's avatar
Tri Dao committed
300
301
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
Tri Dao's avatar
Tri Dao committed
302
303
304
305
306
307
308
309
    at::PhiloxCudaState rng_engine_inputs;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
    }

310
    run_fmha_fp16_sm80(launch_params);
Tri Dao's avatar
Tri Dao committed
311

312
    std::vector<at::Tensor> result = {softmax_lse};
Tri Dao's avatar
Tri Dao committed
313
314
315
316
317
318
    if (return_softmax) {result.push_back(s);}
    return result;
}


std::vector<at::Tensor>
Tri Dao's avatar
Tri Dao committed
319
320
321
322
323
mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
        const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &v,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &out,   // total_q x num_heads x head_size
Tri Dao's avatar
Tri Dao committed
324
        const at::Tensor &softmax_lse_,     // b x h x s softmax logsumexp
Tri Dao's avatar
Tri Dao committed
325
326
327
328
329
330
331
        at::Tensor &dq,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        at::Tensor &dk,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        at::Tensor &dv,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens_q,  // b+1
        const at::Tensor &cu_seqlens_k,  // b+1
        const int max_seqlen_q_,
        const int max_seqlen_k_,          // max sequence length to choose the kernel
Tri Dao's avatar
Tri Dao committed
332
333
334
335
        const float p_dropout,         // probability to drop
        const float softmax_scale,
        const bool zero_tensors,
        const bool is_causal,
Tri Dao's avatar
Tri Dao committed
336
        const int num_splits,
Tri Dao's avatar
Tri Dao committed
337
338
339
        c10::optional<at::Generator> gen_
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
Tri Dao's avatar
Tri Dao committed
340
    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
Tri Dao's avatar
Tri Dao committed
341
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
Tri Dao's avatar
Tri Dao committed
342
343
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    TORCH_CHECK(is_sm8x || is_sm75);
Tri Dao's avatar
Tri Dao committed
344
345
346
347
348
    auto launch = &run_fmha_dgrad_fp16_sm80;

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

Tri Dao's avatar
Tri Dao committed
349
350
351
352
353
354
355
356
357
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
    TORCH_CHECK(k.dtype() == q_dtype);
    TORCH_CHECK(v.dtype() == q_dtype);
    TORCH_CHECK(out.dtype() == q_dtype);
    TORCH_CHECK(dout.dtype() == q_dtype);
    TORCH_CHECK(dq.dtype() == q_dtype);
    TORCH_CHECK(dk.dtype() == q_dtype);
    TORCH_CHECK(dv.dtype() == q_dtype);
Tri Dao's avatar
Tri Dao committed
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);

    TORCH_CHECK(q.is_cuda());
    TORCH_CHECK(k.is_cuda());
    TORCH_CHECK(v.is_cuda());
    TORCH_CHECK(out.is_cuda());
    TORCH_CHECK(dout.is_cuda());
    TORCH_CHECK(softmax_lse_.is_cuda());
    TORCH_CHECK(cu_seqlens_q.is_cuda());
    TORCH_CHECK(cu_seqlens_k.is_cuda());

    TORCH_CHECK(q.stride(-1) == 1);
    TORCH_CHECK(k.stride(-1) == 1);
    TORCH_CHECK(v.stride(-1) == 1);
    TORCH_CHECK(out.is_contiguous());
    TORCH_CHECK(dout.is_contiguous());
    TORCH_CHECK(dq.stride(-1) == 1);
    TORCH_CHECK(dk.stride(-1) == 1);
    TORCH_CHECK(dv.stride(-1) == 1);
    TORCH_CHECK(cu_seqlens_q.is_contiguous());
    TORCH_CHECK(cu_seqlens_k.is_contiguous());

    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    const int total_q = sizes[TOTAL_DIM];
Tri Dao's avatar
Tri Dao committed
385
386
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
Tri Dao's avatar
Tri Dao committed
387
    const int total_k = k.size(TOTAL_DIM);
Tri Dao's avatar
Tri Dao committed
388
    TORCH_CHECK(batch_size > 0);
389
390
    TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
    if (head_size > 64) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
Tri Dao's avatar
Tri Dao committed
391
392
        TORCH_CHECK(is_sm80);
    }
Tri Dao's avatar
Tri Dao committed
393

Tri Dao's avatar
Tri Dao committed
394
395
396
397
398
399
400
401
402
403
404
    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads, head_size);
    CHECK_SHAPE(v, total_k, num_heads, head_size);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size);
    CHECK_SHAPE(dq, total_q, num_heads, head_size);
    CHECK_SHAPE(dk, total_k, num_heads, head_size);
    CHECK_SHAPE(dv, total_k, num_heads, head_size);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

405
    int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256;
Tri Dao's avatar
Tri Dao committed
406
407
408
409
410
    int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
    if( max_seqlen_k_ <= 128 ) {
        max_seqlen_k = 128;
    } else if( max_seqlen_k_ <= 256 ) {
        max_seqlen_k = 256;
Tri Dao's avatar
Tri Dao committed
411
    }
Tri Dao's avatar
Tri Dao committed
412
413
    int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
    bool loop = max_seqlen_k > blocksize_c;
Tri Dao's avatar
Tri Dao committed
414

415
    // Otherwise the kernel will be launched from cuda:0 device
416
417
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
418

Tri Dao's avatar
Tri Dao committed
419
420
    // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
    auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
Tri Dao's avatar
Tri Dao committed
421

Tri Dao's avatar
Tri Dao committed
422
423
    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
Tri Dao's avatar
Tri Dao committed
424
    at::Tensor dq_tmp;
Tri Dao's avatar
Tri Dao committed
425
    if (loop) { dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
Tri Dao's avatar
Tri Dao committed
426
427

    if( zero_tensors ) {
Tri Dao's avatar
Tri Dao committed
428
429
430
        dq.zero_();
        dk.zero_();
        dv.zero_();
Tri Dao's avatar
Tri Dao committed
431
432
433
        softmax_d.zero_();
    }

Tri Dao's avatar
Tri Dao committed
434
435
436
437
438
439
440
441
    FMHA_dgrad_params params;

    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q,
                     max_seqlen_k,
                     num_heads,
                     head_size,
442
                     q, k, v, out,
Tri Dao's avatar
Tri Dao committed
443
444
445
446
447
448
449
450
451
                     dq, dk, dv,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? dq_tmp.data_ptr() : nullptr,
                     dout.data_ptr(),
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
Tri Dao's avatar
Tri Dao committed
452
                     is_causal,
Tri Dao's avatar
Tri Dao committed
453
454
455
456
457
                     num_splits);

    launch(params, stream, /*configure=*/true);

    if (params.num_splits > 1) {
458
459
460
461
462
463
        if (!dq_tmp.defined()) {
            dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
            params.o_tmp_ptr = dq_tmp.data_ptr();  // o_tmp stores dq_tmp in the backward pass
        } else {
            dq_tmp.zero_();
        }
Tri Dao's avatar
Tri Dao committed
464
    }
Tri Dao's avatar
Tri Dao committed
465
466
467
468

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

469
470
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
Tri Dao's avatar
Tri Dao committed
471
472
473
474
475
476
477

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

Tri Dao's avatar
Tri Dao committed
478
479
    launch(params, stream, /*configure=*/false);

480
481
482
483
    if (params.num_splits > 1) {
        dq.copy_(dq_tmp);
    }

Tri Dao's avatar
Tri Dao committed
484
    return { dq, dk, dv, softmax_d };
Tri Dao's avatar
Tri Dao committed
485
486
487
}

std::vector<at::Tensor>
Tri Dao's avatar
Tri Dao committed
488
489
490
491
492
mha_fwd_block(const at::Tensor &q,         // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
              const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const at::Tensor &v,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const at::Tensor &cu_seqlens_q,  // b+1
              const at::Tensor &cu_seqlens_k,  // b+1
Tri Dao's avatar
Tri Dao committed
493
              const at::Tensor &blockmask,   // (seqlen / 256, seqlen / 16)
Tri Dao's avatar
Tri Dao committed
494
495
              const int max_seqlen_q_,
              const int max_seqlen_k_,
Tri Dao's avatar
Tri Dao committed
496
497
498
499
500
501
502
503
504
505
              const float p_dropout,
              const float softmax_scale,
              const bool is_causal,
              const bool return_softmax,
              c10::optional<at::Generator> gen_) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    bool is_dropout = p_dropout > 0.0;
Tri Dao's avatar
Tri Dao committed
506
    Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
Tri Dao's avatar
Tri Dao committed
507

Tri Dao's avatar
Tri Dao committed
508
509
510
511
512
513
    TORCH_CHECK(q.dtype() == torch::kFloat16);
    TORCH_CHECK(k.dtype() == torch::kFloat16);
    TORCH_CHECK(v.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
    TORCH_CHECK(blockmask.dtype() == torch::kInt32);
Tri Dao's avatar
Tri Dao committed
514

Tri Dao's avatar
Tri Dao committed
515
516
517
518
519
    TORCH_CHECK(q.is_cuda());
    TORCH_CHECK(k.is_cuda());
    TORCH_CHECK(v.is_cuda());
    TORCH_CHECK(cu_seqlens_q.is_cuda());
    TORCH_CHECK(cu_seqlens_k.is_cuda());
Tri Dao's avatar
Tri Dao committed
520
521
    TORCH_CHECK(blockmask.is_cuda())

Tri Dao's avatar
Tri Dao committed
522
523
524
525
526
    TORCH_CHECK(q.stride(-1) == 1);
    TORCH_CHECK(k.stride(-1) == 1);
    TORCH_CHECK(v.stride(-1) == 1);
    TORCH_CHECK(cu_seqlens_k.is_contiguous());
    TORCH_CHECK(cu_seqlens_k.is_contiguous());
Tri Dao's avatar
Tri Dao committed
527
528
    TORCH_CHECK(blockmask.is_contiguous())

Tri Dao's avatar
Tri Dao committed
529
    const auto sizes = q.sizes();
Tri Dao's avatar
Tri Dao committed
530

Tri Dao's avatar
Tri Dao committed
531
532
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int total_q = sizes[TOTAL_DIM];
Tri Dao's avatar
Tri Dao committed
533
534
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
Tri Dao's avatar
Tri Dao committed
535
    const int total_k = k.size(TOTAL_DIM);
Tri Dao's avatar
Tri Dao committed
536
    TORCH_CHECK(batch_size > 0);
Tri Dao's avatar
Tri Dao committed
537
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
Tri Dao's avatar
Tri Dao committed
538

Tri Dao's avatar
Tri Dao committed
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads, head_size);
    CHECK_SHAPE(v, total_k, num_heads, head_size);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256;
    if( max_seqlen_k <= 256 ) {
        max_seqlen_k = 256;
    }
    int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
    bool loop = max_seqlen_k > 256;
    CHECK_SHAPE(blockmask, max_seqlen_k / 256, max_seqlen_q / 16);

    auto opts = q.options();

    auto o = torch::zeros({ total_q, num_heads, head_size }, opts);
Tri Dao's avatar
Tri Dao committed
556
557
558
559

    at::Tensor o_tmp;
    if (loop) {
        // o_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
Tri Dao's avatar
Tri Dao committed
560
        o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
Tri Dao's avatar
Tri Dao committed
561
562
    }

Tri Dao's avatar
Tri Dao committed
563
564
    // auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
Tri Dao's avatar
Tri Dao committed
565
566
567

    at::Tensor s;
    if (return_softmax) {
Tri Dao's avatar
Tri Dao committed
568
        s = torch::zeros({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts);
Tri Dao's avatar
Tri Dao committed
569
570
571
572
573
    }

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

Tri Dao's avatar
Tri Dao committed
574
575
576
577
578
579
    set_params_fprop(launch_params.params,
                     batch_size,
                     max_seqlen_q,
                     max_seqlen_k,
                     num_heads,
                     head_size,
580
                     q, k, v, o,
Tri Dao's avatar
Tri Dao committed
581
582
583
584
585
586
587
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? o_tmp.data_ptr() : nullptr,
                     return_softmax ? s.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
Tri Dao's avatar
Tri Dao committed
588
589
                     is_causal,
                     /*num_splits=*/1);
Tri Dao's avatar
Tri Dao committed
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
    launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());

    run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true);
    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    int64_t counter_offset = launch_params.elts_per_thread;
    at::PhiloxCudaState rng_engine_inputs;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    run_fmha_block_fp16_sm80(launch_params, /*configure=*/false);

Tri Dao's avatar
Tri Dao committed
606
    std::vector<at::Tensor> result = {o, softmax_lse};
Tri Dao's avatar
Tri Dao committed
607
608
609
610
611
612
    if (return_softmax) {result.push_back(s);}
    return result;
}

// Backward pass of block-sparse fused multi-head attention (fp16, SM8x GPUs only).
//
// Validates every input, fills an FMHA_dgrad_params struct via set_params_dgrad and launches
// run_fmha_block_dgrad_fp16_sm80 on the current CUDA stream. dq, dk and dv are handed to the
// kernel as output buffers. softmax_d is allocated here.
//
// Sequence lengths are padded before use:
//   * max_seqlen_k is rounded up to a multiple of 256 (minimum 256);
//   * max_seqlen_q is rounded up to a multiple of 16.
// blockmask must have shape (padded max_seqlen_k / 256, padded max_seqlen_q / 16).
// When the padded max_seqlen_k exceeds 256, a float32 dq accumulation buffer (dq_tmp) is also
// allocated and passed to the kernel.
//
// Returns {dq, dk, dv, softmax_d}. softmax_d is float32 with shape
// (batch_size, num_heads, padded max_seqlen_q).
std::vector<at::Tensor>
mha_bwd_block(const at::Tensor &dout,  // total_q x num_heads x head_size
              const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
              const at::Tensor &k,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const at::Tensor &v,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const at::Tensor &out,   // total_q x num_heads x head_size
              const at::Tensor &softmax_lse_,     // b x h x s softmax logsumexp (float32)
              at::Tensor &dq,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
              at::Tensor &dk,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              at::Tensor &dv,   // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
              const at::Tensor &cu_seqlens_q,  // b+1
              const at::Tensor &cu_seqlens_k,  // b+1
              const at::Tensor &blockmask,   // (padded max_seqlen_k / 256, padded max_seqlen_q / 16)
              const int max_seqlen_q_,
              const int max_seqlen_k_,          // max sequence length to choose the kernel
              const float p_dropout,         // probability to drop
              const float softmax_scale,
              const bool is_causal,
              c10::optional<at::Generator> gen_
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    const bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
    const bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    TORCH_CHECK(is_sm8x, "mha_bwd_block only supports SM8x GPUs");
    auto launch = &run_fmha_block_dgrad_fp16_sm80;

    const bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // ---- dtype checks -------------------------------------------------------------------
    TORCH_CHECK(q.dtype() == torch::kFloat16, "q must be fp16");
    TORCH_CHECK(k.dtype() == torch::kFloat16, "k must be fp16");
    TORCH_CHECK(v.dtype() == torch::kFloat16, "v must be fp16");
    TORCH_CHECK(out.dtype() == torch::kFloat16, "out must be fp16");
    TORCH_CHECK(dout.dtype() == torch::kFloat16, "dout must be fp16");
    TORCH_CHECK(dq.dtype() == torch::kFloat16, "dq must be fp16");
    TORCH_CHECK(dk.dtype() == torch::kFloat16, "dk must be fp16");
    TORCH_CHECK(dv.dtype() == torch::kFloat16, "dv must be fp16");
    // The logsumexp is consumed through a raw pointer alongside the float32 softmax_d buffer.
    TORCH_CHECK(softmax_lse_.dtype() == torch::kFloat32, "softmax_lse must be float32");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must be int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must be int32");
    TORCH_CHECK(blockmask.dtype() == torch::kInt32, "blockmask must be int32");

    // ---- device checks (outputs included: the kernel writes through their pointers) ------
    TORCH_CHECK(q.is_cuda(), "q must be a CUDA tensor");
    TORCH_CHECK(k.is_cuda(), "k must be a CUDA tensor");
    TORCH_CHECK(v.is_cuda(), "v must be a CUDA tensor");
    TORCH_CHECK(out.is_cuda(), "out must be a CUDA tensor");
    TORCH_CHECK(dout.is_cuda(), "dout must be a CUDA tensor");
    TORCH_CHECK(dq.is_cuda(), "dq must be a CUDA tensor");
    TORCH_CHECK(dk.is_cuda(), "dk must be a CUDA tensor");
    TORCH_CHECK(dv.is_cuda(), "dv must be a CUDA tensor");
    TORCH_CHECK(softmax_lse_.is_cuda(), "softmax_lse must be a CUDA tensor");
    TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be a CUDA tensor");
    TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be a CUDA tensor");
    TORCH_CHECK(blockmask.is_cuda(), "blockmask must be a CUDA tensor");

    // ---- layout checks ------------------------------------------------------------------
    TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "k must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "v must have contiguous last dimension");
    TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
    TORCH_CHECK(dout.is_contiguous(), "dout must be contiguous");
    TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
    TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
    TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
    TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
    TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
    TORCH_CHECK(blockmask.is_contiguous(), "blockmask must be contiguous");

    // ---- shape checks -------------------------------------------------------------------
    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    const int total_q = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    const int total_k = k.size(TOTAL_DIM);
    TORCH_CHECK(batch_size > 0, "cu_seqlens_q must contain at least two entries");
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128,
                "head_size must be 16, 32, 64 or 128");
    if (head_size == 128) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
        TORCH_CHECK(is_sm80, "head_size 128 is only supported on SM80");
    }

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads, head_size);
    CHECK_SHAPE(v, total_k, num_heads, head_size);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size);
    CHECK_SHAPE(dq, total_q, num_heads, head_size);
    CHECK_SHAPE(dk, total_k, num_heads, head_size);
    CHECK_SHAPE(dv, total_k, num_heads, head_size);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    // Pad the sequence lengths to the kernel's tile sizes (256 along k, 16 along q).
    int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256;
    if( max_seqlen_k <= 256 ) {
        max_seqlen_k = 256;
    }
    const int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
    // More than one 256-wide k tile: dq is accumulated in a float32 scratch buffer.
    const bool loop = max_seqlen_k > 256;
    CHECK_SHAPE(blockmask, max_seqlen_k / 256, max_seqlen_q / 16);

    // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could
    // be different, so it is sliced down to max_seqlen_q below. It must not be *shorter*, though:
    // the slice would silently truncate and the kernel (configured with max_seqlen_q) would
    // presumably read past the end of the buffer.
    TORCH_CHECK(softmax_lse_.dim() == 3, "softmax_lse must be 3-dimensional (b x h x s)");
    TORCH_CHECK(softmax_lse_.size(0) == batch_size && softmax_lse_.size(1) == num_heads,
                "softmax_lse must have shape (batch_size, num_heads, seqlen)");
    TORCH_CHECK(softmax_lse_.size(2) >= max_seqlen_q,
                "softmax_lse last dimension is shorter than the padded max_seqlen_q");
    auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor dq_tmp;
    if (loop) {
        // dq_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
        dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
    }

    FMHA_dgrad_params params;

    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q,
                     max_seqlen_k,
                     num_heads,
                     head_size,
                     q, k, v, out,
                     dq, dk, dv,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? dq_tmp.data_ptr() : nullptr,
                     dout.data_ptr(),
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     is_causal,
                     /*num_splits=*/1);
    params.blockmask = static_cast<int *>(blockmask.data_ptr());

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We're gonna reset the rng state in Python after this kernel, so the counter offset
    // here doesn't matter at all. We just choose an arbitrary number;
    int64_t counter_offset = 4;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    launch(params, stream);
    return { dq, dk, dv, softmax_d };
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // Python entry points: dense (fwd/bwd) and block-sparse (fwd_block/bwd_block) passes.
    // module_::def returns the module by reference, so the registrations are chained;
    // they are registered in the same order as before.
    ext.doc() = "Fused Multi-head Self-attention";
    ext.def("fwd", &mha_fwd, "Forward pass")
       .def("bwd", &mha_bwd, "Backward pass")
       .def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)")
       .def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
}