/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "jax/csrc/modules.h"

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>

#include <stdexcept>
#include <string>
#include <vector>
#include <iostream>

#include "common/common.h"
#include "common/util/logging.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

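// FP8 dtypes are the only ones that carry amax/scale/scale_inv quantization
// metadata; the handlers below null those pointers out for non-FP8 outputs.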
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }

std::vector<size_t> MakeShapeVector(NVTEShape shape) {
    return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

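// Number of input chunks an activation consumes: gated variants (GEGLU, SWIGLU,
// REGLU, QGEGLU, SREGLU) read two interleaved chunks of the hidden dimension,
// the plain variants read one.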
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
    switch (activation_enum) {
        case NVTE_Activation_Type::GELU: return 1;
        case NVTE_Activation_Type::GEGLU: return 2;
        case NVTE_Activation_Type::SILU: return 1;
        case NVTE_Activation_Type::SWIGLU: return 2;
        case NVTE_Activation_Type::RELU: return 1;
        case NVTE_Activation_Type::REGLU: return 2;
        case NVTE_Activation_Type::QGELU: return 1;
        case NVTE_Activation_Type::QGEGLU: return 2;
        case NVTE_Activation_Type::SRELU: return 1;
        case NVTE_Activation_Type::SREGLU: return 2;
        default:
            NVTE_ERROR("Unsupported ActivationEnum");
            return 0;  // unreachable: NVTE_ERROR throws
    }
}

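// Descriptors cross the XLA custom-call boundary as raw bytes: PackOpaque
// serializes a trivially-copyable struct into a pybind11::bytes string and
// UnpackOpaque recovers it inside a handler, e.g.
//     const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);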
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
    auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
    return pybind11::bytes(str);
}

template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
    if (opaque_len != sizeof(T)) {
        throw std::runtime_error("Invalid opaque object size");
    }
    return reinterpret_cast<const T *>(opaque);
}

pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum) {
    CustomCallCommonDescriptor desc;
    desc.shape.from_vector(shape);
    desc.in_dtype = in_dtype;
    desc.out_dtype = out_dtype;
    desc.act_enum = act_enum;
    return PackOpaque(desc);
}

pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype,
                                                 size_t act_enum) {
    CustomCallCommonWkDescriptor desc;
    desc.shape.from_vector(shape);
    desc.wkshape.from_vector(wkshape);
    desc.in_dtype = in_dtype;
    desc.out_dtype = out_dtype;
    desc.wk_dtype = wk_dtype;
    desc.act_enum = act_enum;
    return PackOpaque(desc);
}

pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
    CustomCallNormDescriptor desc;
    desc.batch_size = batch_size;
    desc.hidden_size = hidden_size;
    desc.wkspace_size = wkspace_size;
    desc.barrier_size = barrier_size;
    desc.dgamma_part_shape.from_vector(dgamma_part_shape);
    desc.dbeta_part_shape.from_vector(dbeta_part_shape);
    desc.x_dtype = x_dtype;
    desc.w_dtype = w_dtype;
    desc.wkspace_dtype = wkspace_dtype;
    desc.barrier_dtype = barrier_dtype;
    desc.dgamma_part_dtype = dgamma_part_dtype;
    desc.dbeta_part_dtype = dbeta_part_dtype;
    desc.zero_centered_gamma = zero_centered_gamma;
    desc.eps = eps;
    desc.sm_margin = sm_margin;
    return PackOpaque(desc);
}

pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor) {
    return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
                                        dtype, scale_factor});
}

pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training) {
    return PackOpaque(CustomCallFusedAttnDescriptor{
        input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
        bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type,
        mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
}

void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
    auto input_shape = std::vector<size_t>{rows, cols};
    auto output_shape = std::vector<size_t>{cols, rows};

    auto input_tensor = TensorWrapper(input, input_shape, dtype);
    auto transposed_tensor = TensorWrapper(output, output_shape, dtype);

    nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}

void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *input = buffers[0];
    void *output = buffers[1];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto rows = desc.shape.dims[0];
    auto cols = desc.shape.dims[1];
    assert(desc.in_dtype == desc.out_dtype);
    auto dtype = desc.out_dtype;

    TransposeImpl(input, rows, cols, dtype, stream, output);
}

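// Fused FP8 cast + transpose: writes both the casted tensor and its transpose in
// one pass so downstream GEMMs can consume either layout; both outputs share the
// same amax/scale/scale_inv.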
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *input_cast = buffers[4];
    auto *input_cast_trans = buffers[5];
    float *amax_out = reinterpret_cast<float *>(buffers[6]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = std::vector<size_t>{m, n};
    auto input_trans_shape = std::vector<size_t>{n, m};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto input_cast_tensor =
        TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape,
                                                 desc.out_dtype, amax_out, scale, scale_inv);

    nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(),
                        input_cast_trans_tensor.data(), stream);
}

void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
               NVTE_Activation_Type act_enum) {
    auto act_len = get_activation_len(act_enum);
    auto input_shape = std::vector<size_t>{m, n * act_len};
    auto output_shape = std::vector<size_t>{m, n};
    auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
    auto output_tensor =
        TensorWrapper(output, output_shape, out_dtype, amax, scale, scale_inverse);
    switch (act_enum) {
      case NVTE_Activation_Type::GELU:
        nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_silu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_relu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
}

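// Non-quantized activation forward: quantization pointers are passed as null.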
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

    ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
              nullptr, nullptr, output, act_enum);
}

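// Activation forward fused with FP8 quantization of the result.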
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    float *amax_out = reinterpret_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

    ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
              scale_inv, amax_out, output, act_enum);
}

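// Activation backward: buffers[0] is the incoming gradient and buffers[1] the
// original activation input; gated variants produce a {m, 2n} gradient.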
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *act_input = buffers[1];
    auto *output = buffers[2];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

    auto act_len = get_activation_len(act_enum);
    auto input_shape = std::vector<size_t>{m, n};
    auto act_input_shape = std::vector<size_t>{m, n * act_len};
    auto output_shape = std::vector<size_t>{m, n * act_len};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
    auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);

    switch (act_enum) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(),
                   output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(),
                    output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(),
                   output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(),
                     output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(),
                   output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(),
                    output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(),
                    output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(),
                     output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(),
                    output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(),
                     output_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
}

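// Workspace-size queries below follow the usual TE pattern: invoke the kernel
// with null data pointers and a default-constructed workspace tensor, which the
// call fills with the required workspace shape/dtype instead of launching work.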
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto output_shape = std::vector<size_t>{batch_size, hidden_size};
    auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
    auto dbias_shape = std::vector<size_t>{hidden_size};

    auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
    auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
    auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);

    TensorWrapper dummy_workspace;

    // For now, all dbias_dact(-s) have the same workspace size
    nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(),
                                    output_tensor.data(), output_trans_tensor.data(),
                                    dbias_tensor.data(), dummy_workspace.data(), nullptr);

    auto work_shape = MakeShapeVector(dummy_workspace.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
    auto *input = buffers[0];
    auto *act_input = buffers[1];
    float *amax = reinterpret_cast<float *>(buffers[2]);
    float *scale = reinterpret_cast<float *>(buffers[3]);
    float *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *output_trans = buffers[6];
    auto *dbias = buffers[7];
    float *amax_out = reinterpret_cast<float *>(buffers[8]);
    void *workspace_ptr = buffers[9];

    const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
    assert(amax == amax_out);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
    auto input_shape = std::vector<size_t>{m, n};
    auto act_input_shape = std::vector<size_t>{m, n};
    auto output_shape = std::vector<size_t>{m, n};
    auto output_trans_shape = std::vector<size_t>{n, m};
    auto dbias_shape = std::vector<size_t>{n};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
    auto output_tensor =
        TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto output_trans_tensor =
        TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);

    auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

    switch (act_enum) {
      case NVTE_Activation_Type::GELU:
        nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                        output_tensor.data(), output_trans_tensor.data(),
                                        dbias_tensor.data(), workspace.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                        output_tensor.data(), output_trans_tensor.data(),
                                        dbias_tensor.data(), workspace.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                        output_tensor.data(), output_trans_tensor.data(),
                                        dbias_tensor.data(), workspace.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                         output_tensor.data(), output_trans_tensor.data(),
                                         dbias_tensor.data(), workspace.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                         output_tensor.data(), output_trans_tensor.data(),
                                         dbias_tensor.data(), workspace.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
}

void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
    auto *input = buffers[0];
    auto *act_input = buffers[1];
    float *amax = reinterpret_cast<float *>(buffers[2]);
    float *scale = reinterpret_cast<float *>(buffers[3]);
    float *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *output_trans = buffers[6];
    float *amax_out = reinterpret_cast<float *>(buffers[7]);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    assert(amax == amax_out);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
    auto input_shape = desc.shape.to_vector();
    auto act_input_shape = std::vector<size_t>{m, n * 2};
    auto output_shape = std::vector<size_t>{m, n * 2};
    auto output_trans_shape = std::vector<size_t>{n * 2, m};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
    auto output_tensor =
        TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto output_trans_tensor =
        TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);

    switch (act_enum) {
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), output_trans_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                    output_tensor.data(), output_trans_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), output_trans_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                    output_tensor.data(), output_trans_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                    output_tensor.data(), output_trans_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
}

pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto output_shape = std::vector<size_t>{batch_size, hidden_size};
    auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
    auto dbias_shape = std::vector<size_t>{hidden_size};

    auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
    auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
    auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);

    TensorWrapper dummy_workspace;

    nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
                              output_trans_tensor.data(), dbias_tensor.data(),
                              dummy_workspace.data(), nullptr);

    auto work_shape = MakeShapeVector(dummy_workspace.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                        size_t opaque_len) {
    auto *input = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    auto *output_trans = buffers[5];
    auto *dbias = buffers[6];
    float *amax_out = reinterpret_cast<float *>(buffers[7]);
    void *workspace_ptr = buffers[8];

    const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
    assert(amax == amax_out);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = std::vector<size_t>{m, n};
    auto output_shape = std::vector<size_t>{m, n};
    auto output_trans_shape = std::vector<size_t>{n, m};
    auto dbias_shape = std::vector<size_t>{n};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto output_tensor =
        TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto output_trans_tensor =
        TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);

    auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

    nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(),
                              output_trans_tensor.data(), dbias_tensor.data(),
                              workspace.data(), stream);
}

pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};

    // empty tensor wrappers are okay just to get workspace size
    auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
    auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
    auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);

    // dummy tensor wrappers that will carry workspace size info later
    TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);

        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
                           num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
                         dummy_barrier_tensor.data());
    }

    auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
    auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
                                std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
}

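// Shared forward implementation for LayerNorm and RMSNorm: a null `bias`
// selects the RMSNorm path (no beta and no mu intermediate).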
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
                          size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
                          DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
                          DType out_dtype, void *workspace, DType work_dtype, void *barrier,
                          DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
                          float *scale_inv, cudaStream_t stream) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto workspace_shape = std::vector<size_t>{workspace_size};
    auto barrier_shape = std::vector<size_t>{barrier_size};
    auto is_layer_norm = (bias != nullptr);

    auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);

    // assume output dtype = input dtype
    // If we need mixed I/O precision in the future, we need an additional
    // parameter for output type
    auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);

    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;

    auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
    auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);

    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);

        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
                           num_sm, workspace_tensor.data(), barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto intermediates_dtype = DType::kFloat32;

    // empty tensor wrappers are okay just to get workspace size
    auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
    auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
    auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);

    // dummy tensor wrappers that will carry workspace size info later
    TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
    TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;

    // initialize dBeta information here -- layernorm will modify but RMSnorm will not
    std::vector<size_t> dbeta_part_shape;
    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);

        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(),
                           dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
                           num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());

        dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
                         dummy_barrier_tensor.data());

        dbeta_part_shape = std::vector<size_t>{0, 0};
    }

    auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
    auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
    auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
                                std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
                                std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
                                std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
}

void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
                           size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
                           bool zero_centered_gamma, float eps, void *input, DType in_dtype,
                           void *weight, DType w_dtype, void *ograd, void *workspace,
                           DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
                           void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
                           DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
                           cudaStream_t stream) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto intermediates_dtype = DType::kFloat32;
    auto is_layer_norm = (dbeta != nullptr);

    // assume input type = output type
    auto *grad_output = ograd;
    auto x_dtype = in_dtype;
    auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);

    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);

    auto *x = input;
    auto x_tensor = TensorWrapper(x, input_shape, x_dtype);

    auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
    auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
    auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);

    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;

    auto workspace_shape = std::vector<size_t>{wkspace_size};
    auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
    auto barrier_shape = std::vector<size_t>{barrier_size};
    auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
    auto dgamma_part_tensor =
        TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);

    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
        auto dbeta_part_tensor =
            TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);

        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
                           dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                           barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *amax = reinterpret_cast<float *>(buffers[3]);
    auto *scale = reinterpret_cast<float *>(buffers[4]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
    auto *output = buffers[6];
    auto *mu = buffers[7];
    auto *rsigma = buffers[8];
    auto *amax_out = buffers[9];
    auto *workspace = buffers[10];
    auto *barrier = buffers[11];
    assert(amax_out == amax);

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *output = buffers[3];
    auto *mu = buffers[4];
    auto *rsigma = buffers[5];
    auto *workspace = buffers[6];
    auto *barrier = buffers[7];

    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto out_dtype = in_dtype;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);

    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto dgamma_part_shape = desc.dgamma_part_shape;
    auto dbeta_part_shape = desc.dbeta_part_shape;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto dgamma_part_dtype = desc.dgamma_part_dtype;
    auto dbeta_part_dtype = desc.dbeta_part_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    auto *ograd = buffers[0];
    auto *mu = buffers[1];
    auto *rsigma = buffers[2];
    auto *input = buffers[3];
    auto *weight = buffers[4];
    auto *xgrad = buffers[5];
    auto *wgrad = buffers[6];
    auto *dbeta = buffers[7];
    auto *workspace = buffers[8];
    auto *barrier = buffers[9];
    auto *dgamma_part = buffers[10];
    auto *dbeta_part = buffers[11];

    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                          w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                          rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                          dbeta_part_dtype, stream);
}

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *amax = reinterpret_cast<float *>(buffers[2]);
    auto *scale = reinterpret_cast<float *>(buffers[3]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *rsigma = buffers[6];
    auto *amax_out = buffers[7];
    auto *workspace = buffers[8];
    auto *barrier = buffers[9];
    assert(amax_out == amax);

    void *bias = nullptr;
    void *mu = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;
    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *output = buffers[2];
    auto *rsigma = buffers[3];
    auto *workspace = buffers[4];
    auto *barrier = buffers[5];

    void *bias = nullptr;
    void *mu = nullptr;
    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;
    auto out_dtype = in_dtype;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *ograd = buffers[0];
    auto *rsigma = buffers[1];
    auto *input = buffers[2];
    auto *weight = buffers[3];
    auto *xgrad = buffers[4];
    auto *wgrad = buffers[5];
    auto *workspace = buffers[6];
    auto *barrier = buffers[7];
    auto *dgamma_part = buffers[8];

    void *mu = nullptr;
    void *dbeta = nullptr;
    void *dbeta_part = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto dgamma_part_shape = desc.dgamma_part_shape;
    Shape dbeta_part_shape;
    dbeta_part_shape.from_vector({0, 0});
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto dgamma_part_dtype = desc.dgamma_part_dtype;
    auto dbeta_part_dtype = DType::kByte;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;

    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                          w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                          rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                          dbeta_part_dtype, stream);
}

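// FP8 quantization: casts the input to the FP8 dtype in the descriptor while
// recording amax for delayed-scaling recipes.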
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *amax = reinterpret_cast<float *>(buffers[1]);
    auto *scale = reinterpret_cast<float *>(buffers[2]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    auto *amax_out = reinterpret_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto shape = desc.shape.to_vector();
    auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
    auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);

    nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *amax = reinterpret_cast<float *>(buffers[1]);
    auto *scale = reinterpret_cast<float *>(buffers[2]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

    auto shape = desc.shape.to_vector();
    auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);

    auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);

    nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}

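// The softmax kernels below treat the logits as {batch, heads, q_seqlen,
// k_seqlen} and apply scale_factor to the input before the softmax.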
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto input_tensor = TensorWrapper(input, shape, dtype);
    auto output_tensor = TensorWrapper(output, shape, dtype);

    nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor,
                                stream);
}

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    auto *grad_output = buffers[0];
    auto *softmax_output = buffers[1];
    auto *dgrad = buffers[2];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
    auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
    auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);

    nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
                                 dgrad_tensor.data(), desc.scale_factor, stream);
}

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                size_t opaque_len) {
    auto *input = buffers[0];
    auto *mask = buffers[1];
    auto *output = buffers[2];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto io_shape =
        std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
    auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto input_tensor = TensorWrapper(input, io_shape, dtype);
    // The mask is cast to uint8_t
    auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
    auto output_tensor = TensorWrapper(output, io_shape, dtype);

    nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
                                       output_tensor.data(), desc.scale_factor, stream);
}

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 size_t opaque_len) {
    // The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax.
    ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto attn_batch = desc.batch_size * desc.head_dim;
    auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto input_tensor = TensorWrapper(input, shape, dtype);

    auto output_tensor = TensorWrapper(output, shape, dtype);

    nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
                                                    desc.scale_factor, stream);
}

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            size_t opaque_len) {
    auto *grad_output = buffers[0];
    auto *softmax_output = buffers[1];
    auto *dgrad = buffers[2];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto attn_batch = desc.batch_size * desc.head_dim;
    auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
    auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
    auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);

    nvte_scaled_upper_triang_masked_softmax_backward(
        grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
        desc.scale_factor, stream);
}

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_attn_heads, size_t kv_attn_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim) {
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
        head_dim);
    return backend;
}

/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
                                       const CustomCallFusedAttnDescriptor *desc,
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
    auto input_batch = desc->input_batch;
    auto bias_batch = desc->bias_batch;
    auto attn_heads = desc->attn_heads;
    auto bias_heads = desc->bias_heads;
    auto q_max_seqlen = desc->q_max_seqlen;
    auto kv_max_seqlen = desc->kv_max_seqlen;

    // all backends need softmax but expect different shapes/dtypes
    // start with the max512 sequence length softmax shape/dtype and correct later
    tensor_pack->size = 1;
    Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
    softmax_aux->data.dptr = softmax_buf;
    softmax_aux->data.shape =
        std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
    softmax_aux->data.dtype = desc->dtype;

    // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
        tensor_pack->size = 2;
        Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
        rng_state_aux->data.dptr = rng_state_buf;
        rng_state_aux->data.shape = std::vector<size_t>{2};
        rng_state_aux->data.dtype = DType::kInt64;
        // correct softmax shape/dtype
        softmax_aux->data.shape.at(3) = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
        softmax_aux->data.dtype = DType::kFloat32;

        // include bias if enabled
        if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
            tensor_pack->size = 3;
            Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
            bias_aux->data.dptr = bias_buf;
            bias_aux->data.shape =
                std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
            bias_aux->data.dtype = desc->dtype;
        }
    }
}
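
/*
    NOTE: For reference, the aux tensor pack layouts assembled above are:

        max512 backend:            tensors[0] = softmax_aux {B, H, Qs, Ks}, input dtype
        arbitrary_seqlen backend:  tensors[0] = softmax_aux {B, H, Qs, 1},  float32
                                   tensors[1] = rng_state   {2},            int64
                                   tensors[2] = bias        {bB, bH, Qs, Ks}, input dtype
                                                            (only if a bias is enabled)
*/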

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
                                        const CustomCallFusedAttnDescriptor *desc,
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
    // Backward calls put everything into the tensor pack for every backend
    // so we set dummy bias_type and backend choices here to follow the correct code path
    auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
    auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
    PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend,
                                      softmax_buf, rng_state_buf, bias_buf);

    // correct softmax shape for max512 sequence length kernel
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
        Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
        softmax_aux->data.shape.at(3) = desc->kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
        softmax_aux->data.dtype = desc->dtype;
    }
}
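
/*
    NOTE: Concretely, the backward pack always carries softmax_aux, rng_state, and bias in
    that order; when the max512 kernel is active only the softmax entry is patched back to
    the {B, H, Qs, Ks} shape and input dtype, and any entries the active kernel does not
    need are ignored by nvte_fused_attn_bwd().
*/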

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
    // For qkv_packed
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

    // For kv_packed
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

    // For separate q, k, v
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto v_shape = k_shape;
    auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
    auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
    auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    TensorWrapper query_workspace_tensor;
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        assert(q_max_seqlen == kv_max_seqlen);
        nvte_fused_attn_fwd_qkvpacked(
            qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
            q_max_seqlen, is_training, scaling_factor,
            dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
            nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        nvte_fused_attn_fwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_rng_state_tensor.data(), q_max_seqlen,
            kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
            mask_type, query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                            s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            dummy_rng_state_tensor.data(),
                            q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
                            dropout_probability, qkv_layout, bias_type, mask_type,
                            query_workspace_tensor.data(), nullptr);
    } else {
        NVTE_ERROR("Unsupported QKVLayout.");
    }

    auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
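
/*
    NOTE: The (shape, dtype) tuple returned above is used on the host side to allocate the
    cuDNN workspace before the real forward call; the element count and dtype then travel
    back in through descriptor.wkspace_size / descriptor.wkspace_dtype, and the buffer shows
    up as buffers[10] in FusedAttnForward below. A minimal host-side sketch with
    illustrative variable names:

        pybind11::tuple info = GetFusedAttnForwardWorkspaceSizes(...);
        auto shape = info[0].cast<std::vector<size_t>>();
        size_t num_elements = 1;
        for (auto dim : shape) num_elements *= dim;  // wkspace_size is an element count
*/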

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads,
    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype,
    bool is_training) {
    auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);

    auto bias_shape = std::vector<size_t>{1, attn_heads, q_max_seqlen, kv_max_seqlen};
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    // F16 doesn't use s_tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);

    TensorWrapper query_workspace_tensor;

    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        assert(q_max_seqlen == kv_max_seqlen);
        auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
        auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
        nvte_fused_attn_bwd_qkvpacked(
            qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            q_max_seqlen, scaling_factor, dropout_probability,
            qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto kv_shape =
            std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
        auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
        nvte_fused_attn_bwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            q_max_seqlen, kv_max_seqlen, scaling_factor,
            dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
            nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
        auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
        auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
        auto v_shape = k_shape;
        auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
        auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
        nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                            doutput_tensor.data(),
                            s_tensor.data(),  // not used for F16
                            s_tensor.data(),  // not used for F16
                            &aux_input_tensors,
                            dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                            dbias_tensor.data(),
                            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
                            qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
                            nullptr);
    } else {
        NVTE_ERROR("Unsupported QKVLayout.");
    }

    auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
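
/*
    NOTE: A second overload below sizes the backward workspace with a single
    nvte_fused_attn_bwd() call on separate q/k/v tensors instead of dispatching on
    qkv_layout as done here.
*/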

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    /* Input buffers from XLA */
    /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
    void *bias = buffers[3];
    void *q_cu_seqlens = buffers[4];
    void *kv_cu_seqlens = buffers[5];
    void *seed = buffers[6];

    /* Output buffers from XLA */
    void *output = buffers[7];
    void *softmax_aux = buffers[8];
    void *rng_state = buffers[9];
    void *workspace = buffers[10];

    /* Descriptor */
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    auto qkv_layout = descriptor.qkv_layout;
    auto dtype = descriptor.dtype;

    /* Input tensors */
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

    /* Output tensors */
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
    auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto o_tensor = TensorWrapper(output, o_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    /* Prepare RNG state */
    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

    /* Auxiliary tensors (to be propagated to the backward pass later) */
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);
    PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                      softmax_aux);

    /* cuDNN workspace */
    auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                          descriptor.wkspace_dtype);

    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    /* Call the underlying NVTE API */
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        auto qkv = buffers[0];
        auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
        nvte_fused_attn_fwd_qkvpacked(
            qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
            descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        auto q = buffers[0];
        auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(q, q_shape, dtype);
        auto kv = buffers[1];
        auto kv_shape =
            std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
        nvte_fused_attn_fwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
            descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
            mask_type, workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        auto q = buffers[0];
        auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(q, q_shape, dtype);
        auto k = buffers[1];
        auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
        auto k_tensor = TensorWrapper(k, k_shape, dtype);
        auto v = buffers[2];
        auto v_shape = k_shape;
        auto v_tensor = TensorWrapper(v, v_shape, dtype);
        nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                            s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                            descriptor.is_training, scaling_factor,
                            dropout_probability, qkv_layout, bias_type, mask_type,
                            workspace_tensor.data(), stream);
    } else {
        NVTE_ERROR("Unsupported qkv_layout.");
    }

    nvte_tensor_pack_destroy(&aux_output_tensors);
}
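
/*
    NOTE: FusedAttnForward follows the usual JAX custom-call convention: XLA invokes it with
    the CUDA stream, the flat buffer list, and the opaque descriptor. A sketch of how such a
    target is commonly exposed to JAX via pybind11; the EncapsulateFunction helper and the
    "te_fused_attn_forward" target name are assumptions for illustration, while the capsule
    name "xla._CUSTOM_CALL_TARGET" is the fixed string JAX expects:

        template <typename T>
        pybind11::capsule EncapsulateFunction(T *fn) {
            return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
        }

        // registrations["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
*/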

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);

    TensorWrapper query_workspace_tensor;
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                        doutput_tensor.data(),
                        s_tensor.data(),  // not used for F16
                        s_tensor.data(),  // not used for F16
                        &aux_input_tensors,
                        dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                        dbias_tensor.data(),
                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                        q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
                        qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    /* Input buffers from XLA */
    /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
    void *bias = buffers[3];
    void *softmax_aux = buffers[4];
    void *rng_state = buffers[5];
    void *output = buffers[6];
    void *doutput = buffers[7];
    void *q_cu_seqlens = buffers[8];
    void *kv_cu_seqlens = buffers[9];

    /* Output buffers from XLA */
    /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */
    void *dbias = buffers[13];
    void *workspace = buffers[14];

    /* Descriptor */
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    auto qkv_layout = descriptor.qkv_layout;
    auto dtype = descriptor.dtype;

    /* Input tensors */
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

    /* Output tensors */
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    /* Auxiliary tensors (propagated from the forward pass) */
    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                       rng_state, bias);

    /* cuDNN workspace */
    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
    auto wkspace_dtype = descriptor.wkspace_dtype;
    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    /* Call the underlying NVTE API */
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        auto qkv = buffers[0];
        auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
        auto dqkv = buffers[10];
        auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
        nvte_fused_attn_bwd_qkvpacked(
            qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            q_max_seqlen, scaling_factor, dropout_probability,
            qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        auto q = buffers[0];
        auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(q, q_shape, dtype);
        auto kv = buffers[1];
        auto kv_shape =
            std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
        auto dq = buffers[10];
        auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
        auto dkv = buffers[11];
        auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
        nvte_fused_attn_bwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
            q_max_seqlen, kv_max_seqlen, scaling_factor,
            dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        auto q = buffers[0];
        auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(q, q_shape, dtype);
        auto k = buffers[1];
        auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
        auto k_tensor = TensorWrapper(k, k_shape, dtype);
        auto v = buffers[2];
        auto v_shape = k_shape;
        auto v_tensor = TensorWrapper(v, v_shape, dtype);
        auto dq = buffers[10];
        auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
        auto dk = buffers[11];
        auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
        auto dv = buffers[12];
        auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
        nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                            doutput_tensor.data(),
                            s_tensor.data(),  // not used for F16
                            s_tensor.data(),  // not used for F16
                            &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
                            dv_tensor.data(), dbias_tensor.data(),
                            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                            q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
                            qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
    } else {
        NVTE_ERROR("Unsupported qkv_layout.");
    }

    nvte_tensor_pack_destroy(&aux_input_tensors);
}

}  // namespace jax
}  // namespace transformer_engine