// fmha_api.cpp (blame view: changes committed by Tri Dao)
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "fmha.h"

// Fill `params` for the FMHA forward/backward kernels.
//
// Sizes: b = batch size, s = (padded) sequence length chosen by the caller,
// h = number of heads, d = head size.
// qkv is packed per token as [3][h][d] FP16 values: Q, K and V start h*d
// elements apart, heads are d elements apart and consecutive tokens are
// 3*h*d elements apart (see the strides below).
// The output (o_packed_d) is packed per token as [h][d].
// Optional buffers (o_tmp_d, do_packed_d, s_d, dsoftmax_sum_d) may be nullptr;
// callers pass nullptr when the corresponding buffer is unused.
// p_dropout is the probability of DROPPING an element and must be < 1; it is
// stored as the keep probability (1 - p_dropout).
// softmax_scale is the scale applied to the first GEMM (scale_bmm1).
void set_params(Fused_multihead_attention_fprop_params &params,
                // sizes
                const size_t b,
                const size_t s,
                const size_t h,
                const size_t d,
                // device pointers
                void *qkv_packed_d,
                void *cu_seqlens_d,
                void *o_packed_d,
                void *o_tmp_d,
                void *do_packed_d,
                void *s_d,
                void *softmax_lse_d,
                void *dsoftmax_sum_d,
                float p_dropout,
                float softmax_scale,
                bool is_causal) {

    // Validate before deriving anything from the keep probability:
    // p_dropout == 1 would make rp_dropout = 1 / 0 below.
    TORCH_CHECK(p_dropout < 1.f, "p_dropout must be < 1");

    Data_type acc_type = DATA_TYPE_FP32;
    Data_type data_type = DATA_TYPE_FP16;

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    // Set the pointers and strides.
    // Byte offsets are applied through a char* because arithmetic on void* is a
    // GNU extension rather than standard C++ (the result is the same address).
    char *qkv_bytes = static_cast<char *>(qkv_packed_d);
    params.q_ptr = qkv_bytes;
    params.k_ptr = qkv_bytes + get_size_in_bytes(h * d, data_type);
    params.v_ptr = qkv_bytes + 2 * get_size_in_bytes(h * d, data_type);
    params.q_row_stride_in_elts = 3 * h * d;
    params.k_row_stride_in_elts = 3 * h * d;
    params.v_row_stride_in_elts = 3 * h * d;
    params.q_head_stride_in_elts = d;
    params.k_head_stride_in_elts = d;
    params.v_head_stride_in_elts = d;
    params.o_ptr = o_packed_d;
    params.o_row_stride_in_elts = h * d;
    params.o_head_stride_in_elts = d;
    params.do_ptr = do_packed_d;
    params.o_tmp_ptr = o_tmp_d;

    params.cu_seqlens = static_cast<int *>(cu_seqlens_d);

    // S = softmax(P)
    params.s_ptr = s_d;
    params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;
    params.dsoftmax_sum = dsoftmax_sum_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.s = s;
    params.d = d;

    // Set the different scale values.
    // const float scale_bmm1 = 1.f / sqrtf(d);
    const float scale_bmm1 = softmax_scale;
    constexpr float scale_softmax = 1.f;
    constexpr float scale_bmm2 = 1.f;

    params.scale_bmm1f = scale_bmm1;
    set_alpha(params.scale_bmm1, scale_bmm1, data_type);
    set_alpha(params.scale_softmax, scale_softmax, acc_type);
    set_alpha(params.scale_bmm2, scale_bmm2, data_type);

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.rp_dropout = 1.f / params.p_dropout;
    set_alpha(params.scale_dropout, params.rp_dropout, data_type);

    params.is_causal = is_causal;
}

std::vector<at::Tensor>
Tri Dao's avatar
Tri Dao committed
115
116
117
118
119
120
121
122
123
124
125
mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens,  // b+1
        const float p_dropout,
        const int max_seq_len,
        const float softmax_scale,
        const bool zero_tensors,
        const bool is_causal,
        const bool return_softmax,
        c10::optional<at::Generator> gen_) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
Tri Dao's avatar
Tri Dao committed
126
    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
Tri Dao's avatar
Tri Dao committed
127
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
Tri Dao's avatar
Tri Dao committed
128
    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
Tri Dao's avatar
Tri Dao committed
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    bool is_dropout = p_dropout > 0.0;
    Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

    TORCH_CHECK(qkv.is_cuda())
    TORCH_CHECK(cu_seqlens.is_cuda())

    TORCH_CHECK(qkv.is_contiguous())
    TORCH_CHECK(cu_seqlens.is_contiguous())

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();

    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);

Tri Dao's avatar
Tri Dao committed
153
    int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
Tri Dao's avatar
Tri Dao committed
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    int seq_len = 512;
    if( max_seq_len <= 128 ) {
        seq_len = 128;
    } else if( max_seq_len <= 256 ) {
        seq_len = 256;
    } else {
        seq_len = ((max_seq_len + base_N - 1) / base_N) * base_N;
    }
    bool loop = seq_len > base_N;

    auto opts = qkv.options();

    auto ctx = torch::empty({ total, num_heads, head_size }, opts);

    at::Tensor o_tmp;
Tri Dao's avatar
Tri Dao committed
169
    if (loop) { o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat)); }
Tri Dao's avatar
Tri Dao committed
170
171
172
173
174

    auto softmax_lse = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
    // auto softmax_lse = torch::full({batch_size, num_heads, seq_len}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));

    at::Tensor s;
Tri Dao's avatar
Tri Dao committed
175
    if (return_softmax) { s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts); }
Tri Dao's avatar
Tri Dao committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229

    if( zero_tensors ) {
        ctx.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (loop) { o_tmp.zero_(); }
        if (return_softmax) {s.zero_();}
    }

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());


    set_params(launch_params.params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               ctx.data_ptr(),
               loop ? o_tmp.data_ptr() : nullptr,
               nullptr,
               return_softmax ? s.data_ptr() : nullptr,
               softmax_lse.data_ptr(),
               nullptr,
               p_dropout,
               softmax_scale,
               is_causal);

    run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    int64_t counter_offset = launch_params.elts_per_thread;
    at::PhiloxCudaState rng_engine_inputs;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    run_fmha_fp16_sm80(launch_params, /*configure=*/false);

    std::vector<at::Tensor> result = {ctx, softmax_lse};
    if (return_softmax) {result.push_back(s);}
    return result;
}


// Backward pass of fused multi-head attention (dense).
//
// Computes dqkv from dout, the forward inputs (qkv, cu_seqlens) and the
// forward outputs (out, softmax_lse_). The kernel writes dP into `softmax`,
// which is returned as the second result.
// Returns {dqkv (same shape/dtype as qkv), softmax, softmax_d (b x num_heads x seq_len, fp32)}.
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,  // total x num_heads x head_size
        const at::Tensor &qkv,   // total x 3 x num_heads x head_size (per-token layout used by set_params), total := \sum_{i=0}^{b} s_i
        const at::Tensor &out,   // total x num_heads x head_size
        at::Tensor &softmax,     // b x h x s x s softmax and dmask - will be overwritten with dP
        const at::Tensor &softmax_lse_,     // b x h x s softmax logsumexp
        const at::Tensor &cu_seqlens,  // b+1
        const float p_dropout,         // probability to drop
        const float softmax_scale,
        const int max_seq_len,          // max sequence length to choose the kernel
        const bool zero_tensors,
        const bool is_causal,
        c10::optional<at::Generator> gen_
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
    auto launch = &run_fmha_dgrad_fp16_sm80;

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // All of these buffers reach the kernel as raw pointers via set_params,
    // so their dtypes must match what the FP16 kernel reads/writes.
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(dout.dtype() == torch::kFloat16);
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax_lse_.dtype() == torch::kFloat32);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

    // Host pointers handed to the kernel would fault, so require device tensors.
    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(dout.is_cuda());
    TORCH_CHECK(out.is_cuda());
    TORCH_CHECK(softmax_lse_.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();

    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
    if (head_size == 128) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
        TORCH_CHECK(is_sm80);
    }

    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
    int seq_len = 512;
    if( max_seq_len <= 128 ) {
        seq_len = 128;
    } else if( max_seq_len <= 256 ) {
        seq_len = 256;
    } else {
        seq_len = ((max_seq_len + base_N - 1) / base_N) * base_N;
    }
    bool loop = seq_len > base_N;

    // It's possible the softmax_lse_ from the fwd has a different length since base_N could be different.
    auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, seq_len)}).contiguous();

    auto dqkv = torch::empty_like(qkv);
    auto opts = qkv.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
    at::Tensor dq_tmp;
    if (loop) { dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat)); }

    if( zero_tensors ) {
        dqkv.zero_();
        softmax_d.zero_();
        if (loop) { dq_tmp.zero_(); }
    }

    Fused_multihead_attention_fprop_params params;

    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               out.data_ptr(),
               loop ? dq_tmp.data_ptr() : nullptr,
               dout.data_ptr(),
               softmax.data_ptr(),  // softmax gets overwritten by dP!
               softmax_lse.data_ptr(),
               softmax_d.data_ptr(),
               p_dropout,
               softmax_scale,
               is_causal);

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We're gonna reset the rng state in Python after this kernel, so the counter offset
    // here doesn't matter at all. We just choose an arbitrary number.
    int64_t counter_offset = 4;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    params.dqkv_ptr = dqkv.data_ptr();

    launch(params, stream);
    return { dqkv, softmax, softmax_d };
    // std::vector<at::Tensor> result = {dqkv, softmax, softmax_d};
    // if (loop) {
    //   result.push_back(dq_tmp);
    // }
    // return result;
}

// Forward pass of block-sparse fused multi-head attention (SM8x only).
//
// Same contract as mha_fwd, plus an int32 `blockmask` that is passed to the
// kernel through params.blockmask. The sequence length is padded to a multiple
// of 256. When it exceeds 256 (`loop`), an fp32 o_tmp scratch buffer is allocated.
// Returns {ctx (total x num_heads x head_size), softmax_lse (b x num_heads x seq_len, fp32)}
// plus S (b x num_heads x seq_len x seq_len) when return_softmax is true.
std::vector<at::Tensor>
mha_fwd_block(const at::Tensor &qkv,         // total x 3 x num_heads x head_size (per-token layout used by set_params), total := \sum_{i=0}^{b} s_i
              const at::Tensor &cu_seqlens,  // b+1
              const at::Tensor &blockmask,   // (seqlen / 256, seqlen / 16)
              const float p_dropout,
              const int max_seq_len,
              const float softmax_scale,
              const bool is_causal,
              const bool return_softmax,
              c10::optional<at::Generator> gen_) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    bool is_dropout = p_dropout > 0.0;
    Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

    bool loop = false;
    int seq_len = 256;
    if( max_seq_len > 256 ) {
        seq_len = ((max_seq_len + 256 - 1) / 256) * 256;
        loop = true;
    }

    // qkv is read as FP16, and cu_seqlens / blockmask are cast to int*, so
    // reject other dtypes (same checks as in mha_bwd_block).
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
    TORCH_CHECK(blockmask.dtype() == torch::kInt32);

    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());
    TORCH_CHECK(blockmask.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());
    TORCH_CHECK(blockmask.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);
    TORCH_CHECK(blockmask.dim() == 2);

    const auto sizes = qkv.sizes();

    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64);
    auto opts = qkv.options();

    auto ctx = torch::zeros({ total, num_heads, head_size }, opts);

    at::Tensor o_tmp;
    if (loop) {
        // o_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
        o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
    }

    // auto softmax_lse = torch::full({batch_size, num_heads, seq_len}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
    auto softmax_lse = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));

    at::Tensor s;
    if (return_softmax) {
        s = torch::zeros({ batch_size, num_heads, seq_len, seq_len }, opts);
    }

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());


    set_params(launch_params.params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               ctx.data_ptr(),
               loop ? o_tmp.data_ptr() : nullptr,
               nullptr,
               return_softmax ? s.data_ptr() : nullptr,
               softmax_lse.data_ptr(),
               nullptr,
               p_dropout,
               softmax_scale,
               is_causal);
    launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());

    // First call only configures launch_params (e.g. elts_per_thread) without launching.
    run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true);
    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    int64_t counter_offset = launch_params.elts_per_thread;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    run_fmha_block_fp16_sm80(launch_params, /*configure=*/false);

    std::vector<at::Tensor> result = {ctx, softmax_lse};
    if (return_softmax) {result.push_back(s);}
    return result;
}

// Backward pass of block-sparse fused multi-head attention (SM8x only).
//
// Same contract as mha_bwd, plus an int32 `blockmask` that is passed to the
// kernel through params.blockmask. The kernel writes dP into `softmax`.
// Returns {dqkv (same shape/dtype as qkv), softmax, softmax_d (b x num_heads x seq_len, fp32)}.
std::vector<at::Tensor>
mha_bwd_block(const at::Tensor &dout,  // total x num_heads x head_size
              const at::Tensor &qkv,   // total x 3 x num_heads x head_size (per-token layout used by set_params), total := \sum_{i=0}^{b} s_i
              const at::Tensor &out,   // total x num_heads x head_size
              at::Tensor &softmax,     // b x h x s x s softmax and dmask - will be overwritten with dP
              const at::Tensor &softmax_lse,     // b x h x s softmax logsumexp
              const at::Tensor &cu_seqlens,  // b+1
              const at::Tensor &blockmask,   // (seqlen / 256, seqlen / 16)
              const float p_dropout,         // probability to drop
              const float softmax_scale,
              const int max_seq_len,          // max sequence length to choose the kernel
              const bool is_causal,
              c10::optional<at::Generator> gen_
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
    bool loop = false;
    int seq_len = 256;
    auto launch = &run_fmha_block_dgrad_fp16_sm80;
    if (max_seq_len > 256) {
        seq_len = ((max_seq_len + 256 - 1) / 256) * 256;
        loop = true;
    }

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // All of these buffers reach the kernel as raw pointers via set_params /
    // params.blockmask, so their dtypes must match what the kernel reads/writes.
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(dout.dtype() == torch::kFloat16);
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax_lse.dtype() == torch::kFloat32);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
    TORCH_CHECK(blockmask.dtype() == torch::kInt32);

    // Host pointers handed to the kernel would fault, so require device tensors.
    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(dout.is_cuda());
    TORCH_CHECK(out.is_cuda());
    TORCH_CHECK(softmax_lse.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());
    TORCH_CHECK(blockmask.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());
    TORCH_CHECK(blockmask.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);
    TORCH_CHECK(blockmask.dim() == 2);

    const auto sizes = qkv.sizes();

    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    const int total = sizes[TOTAL_DIM];
    const int num_heads = sizes[H_DIM];
    const int head_size = sizes[D_DIM];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64);

    auto dqkv = torch::zeros_like(qkv);
    auto opts = qkv.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
    at::Tensor dq_tmp;
    if (loop) {
        // dq_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
        dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
    }

    Fused_multihead_attention_fprop_params params;

    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               out.data_ptr(),
               loop ? dq_tmp.data_ptr() : nullptr,
               dout.data_ptr(),
               softmax.data_ptr(),  // softmax gets overwritten by dP!
               softmax_lse.data_ptr(),
               softmax_d.data_ptr(),
               p_dropout,
               softmax_scale,
               is_causal);
    params.blockmask = static_cast<int *>(blockmask.data_ptr());

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We're gonna reset the rng state in Python after this kernel, so the counter offset
    // here doesn't matter at all. We just choose an arbitrary number.
    int64_t counter_offset = 4;

    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    params.dqkv_ptr = dqkv.data_ptr();

    launch(params, stream);
    return { dqkv, softmax, softmax_d };
    // std::vector<at::Tensor> result = {dqkv, softmax, softmax_d};
    // if (loop) {
    //   result.push_back(dq_tmp);
    // }
    // return result;
}


// Python entry points: the dense kernels are exposed as fwd/bwd, and the
// block-sparse kernels (which additionally take a blockmask) as fwd_block/bwd_block.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "Fused Multi-head Self-attention";

    // Dense attention.
    m.def("fwd", &mha_fwd, "Forward pass")
     .def("bwd", &mha_bwd, "Backward pass");

    // Block-sparse attention.
    m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)")
     .def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
}