modules.cpp 48.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "jax/csrc/modules.h"

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
12
#include <cudnn.h>
13
14
15
16
17
18
19
20

#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
21
#include "common/util/cuda_runtime.h"
22
23
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
24
#include "transformer_engine/fused_attn.h"
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "transformer_engine/gemm.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

// Size (32 * 1024 * 1024) of the workspace requested from WorkspaceManager for the
// cuBLASLt GEMM call; the workspace is wrapped as a DType::kByte tensor, so this is a byte count.
constexpr size_t kCublasLtForwardWorkspaceSize = 32 * 1024 * 1024;
// Same value for the backward direction.
// NOTE(review): no use of this constant is visible in this part of the file — confirm it is still needed.
constexpr size_t kCublasLtBackwardWorkspaceSize = 32 * 1024 * 1024;

// Returns true when `type` is one of the two FP8 dtypes (kFloat8E4M3 or kFloat8E5M2).
inline bool use_fp8(DType type) {
    const bool is_e4m3 = (type == DType::kFloat8E4M3);
    const bool is_e5m2 = (type == DType::kFloat8E5M2);
    return is_e4m3 || is_e5m2;
}

// Copies the object representation of `descriptor` (sizeof(T) bytes) into a
// pybind11::bytes object. The inverse operation is UnpackOpaque<T>.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
    const char *raw_bytes = reinterpret_cast<const char *>(&descriptor);
    const std::string serialized(raw_bytes, sizeof(T));
    return pybind11::bytes(serialized);
}

// Reinterprets an opaque byte buffer as a descriptor of type T.
//
// Returns a pointer into `opaque` itself (no copy is made).
// Throws std::runtime_error when `opaque_len` differs from sizeof(T).
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
    const bool size_matches = (opaque_len == sizeof(T));
    if (size_matches) {
        return reinterpret_cast<const T *>(opaque);
    }
    throw std::runtime_error("Invalid opaque object size");
}

// Fills a CustomCallCommonDescriptor with the given shape and input/output dtypes
// and returns it serialized through PackOpaque.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype) {
    CustomCallCommonDescriptor packed;
    packed.out_dtype = out_dtype;
    packed.in_dtype = in_dtype;
    packed.shape.from_vector(shape);
    return PackOpaque(packed);
}

// Serializes the GEMM problem description (sizes, operand dtypes, transpose flags and
// the split-accumulator flag) as a CustomCallGemmDescriptor.
// The descriptor is brace-initialized positionally, so the argument order below must
// follow the member order of CustomCallGemmDescriptor.
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
                                             DType B_dtype, DType D_dtype, bool transa, bool transb,
                                             bool use_split_accumulator) {
    const CustomCallGemmDescriptor packed{m,       n,      k,      A_dtype,
                                          B_dtype, D_dtype, transa, transb,
                                          use_split_accumulator};
    return PackOpaque(packed);
}

// Serializes the normalization parameters as a CustomCallNormDescriptor and returns
// the raw bytes; the norm custom calls decode them again with
// UnpackOpaque<CustomCallNormDescriptor>.
// The descriptor is brace-initialized positionally, so the argument order below must
// follow the member order of CustomCallNormDescriptor.
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
                                             bool zero_centered_gamma, float eps, int sm_margin) {
    return PackOpaque(
        CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps, sm_margin});
}

// Serializes the softmax problem description as a SoftmaxDescriptor.
// The descriptor is brace-initialized positionally, so the argument order below must
// follow the member order of SoftmaxDescriptor.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
                                                size_t q_seqlen, size_t k_seqlen, DType dtype,
                                                float scale_factor) {
    const SoftmaxDescriptor packed{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype,
                                   scale_factor};
    return PackOpaque(packed);
}

84
85
86
87
88
89
90
91
92
// Serializes the fused-attention problem description as a CustomCallFusedAttnDescriptor.
// The descriptor is brace-initialized positionally, so the argument order below must
// follow the member order of CustomCallFusedAttnDescriptor.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    const CustomCallFusedAttnDescriptor packed{
        batch,          num_head,           q_max_seqlen, kv_max_seqlen, head_dim,
        scaling_factor, dropout_probability, bias_type,    mask_type,     dtype,
        is_training};
    return PackOpaque(packed);
}

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
    auto input_shape = std::vector<size_t>{rows, cols};
    auto output_shape = std::vector<size_t>{cols, rows};

    auto input_tensor = TensorWrapper(input, input_shape, dtype);
    auto transposed_tensor = TensorWrapper(output, output_shape, dtype);

    nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}

// Custom-call entry point for a 2D transpose.
//   buffers[0]: input, buffers[1]: output
// `opaque` must hold a CustomCallCommonDescriptor whose first two shape dims are the
// row and column counts; input and output dtypes are expected to be equal (asserted).
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    void *dst = buffers[1];

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    assert(descriptor->in_dtype == descriptor->out_dtype);

    const auto num_rows = descriptor->shape.dims[0];
    const auto num_cols = descriptor->shape.dims[1];

    TransposeImpl(src, num_rows, num_cols, descriptor->out_dtype, stream, dst);
}

// Custom-call entry point that casts the input and also produces its transpose.
//   buffers[0]: input
//   buffers[1..3]: amax, scale, scale_inv (float)
//   buffers[4]: cast output, buffers[5]: cast + transposed output
//   buffers[6]: amax output, asserted to alias buffers[1]
// When the descriptor's output dtype is not FP8, no amax/scale/scale_inv pointers are
// attached to the output tensors.
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    auto *amax = static_cast<float *>(buffers[1]);
    auto *scale = static_cast<float *>(buffers[2]);
    auto *scale_inv = static_cast<float *>(buffers[3]);
    void *cast_out = buffers[4];
    void *cast_trans_out = buffers[5];
    auto *amax_out = static_cast<float *>(buffers[6]);
    assert(amax == amax_out);

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const bool fp8_output = use_fp8(descriptor->out_dtype);
    float *out_amax = fp8_output ? amax_out : nullptr;
    float *out_scale = fp8_output ? scale : nullptr;
    float *out_scale_inv = fp8_output ? scale_inv : nullptr;

    const auto rows = descriptor->shape.dims[0];
    const auto cols = descriptor->shape.dims[1];
    const std::vector<size_t> dims{rows, cols};
    const std::vector<size_t> transposed_dims{cols, rows};

    TensorWrapper src_tensor(src, dims, descriptor->in_dtype);
    TensorWrapper cast_tensor(cast_out, dims, descriptor->out_dtype, out_amax, out_scale,
                              out_scale_inv);
    TensorWrapper cast_trans_tensor(cast_trans_out, transposed_dims, descriptor->out_dtype,
                                    out_amax, out_scale, out_scale_inv);

    nvte_cast_transpose(src_tensor.data(), cast_tensor.data(), cast_trans_tensor.data(), stream);
}

void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
                   cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
    auto input_shape = std::vector<size_t>{m, n * 2};
    auto output_shape = std::vector<size_t>{m, n};

    auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));

    auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                       scale, scale_inverse);

    nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
}

// Custom-call entry point for gated GELU without scaling metadata.
//   buffers[0]: input, buffers[1]: output
// `opaque` must hold a CustomCallCommonDescriptor; its first two shape dims are passed
// on as m and n, and amax/scale/scale_inverse are passed as null.
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    void *dst = buffers[1];

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const auto rows = descriptor->shape.dims[0];
    const auto cols = descriptor->shape.dims[1];

    float *const no_scale = nullptr;
    float *const no_scale_inv = nullptr;
    float *const no_amax = nullptr;
    GatedGeluImpl(src, rows, cols, descriptor->in_dtype, descriptor->out_dtype, no_scale, stream,
                  no_scale_inv, no_amax, dst);
}

// Custom-call entry point for gated GELU with scaling metadata.
//   buffers[0]: input
//   buffers[1..3]: amax, scale, scale_inv (float)
//   buffers[4]: output
//   buffers[5]: amax output, asserted to alias buffers[1]
// When the descriptor's output dtype is not FP8, the amax/scale/scale_inv pointers are
// replaced with null before calling GatedGeluImpl.
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    auto *amax = static_cast<float *>(buffers[1]);
    auto *scale = static_cast<float *>(buffers[2]);
    auto *scale_inv = static_cast<float *>(buffers[3]);
    void *dst = buffers[4];
    auto *amax_out = static_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const bool fp8_output = use_fp8(descriptor->out_dtype);
    float *out_scale = fp8_output ? scale : nullptr;
    float *out_scale_inv = fp8_output ? scale_inv : nullptr;
    float *out_amax = fp8_output ? amax_out : nullptr;

    const auto rows = descriptor->shape.dims[0];
    const auto cols = descriptor->shape.dims[1];

    GatedGeluImpl(src, rows, cols, descriptor->in_dtype, descriptor->out_dtype, out_scale, stream,
                  out_scale_inv, out_amax, dst);
}

// Custom-call entry point for the gated-GELU backward pass (nvte_dgeglu).
//   buffers[0]: incoming gradient, described as (m, n)
//   buffers[1]: forward gated-GELU input, described as (m, 2 * n)
//   buffers[2]: output gradient, described as (m, 2 * n)
// m and n are the first two shape dims of the CustomCallCommonDescriptor in `opaque`.
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *grad_in = buffers[0];
    void *fwd_input = buffers[1];
    void *grad_out = buffers[2];

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const auto rows = descriptor->shape.dims[0];
    const auto cols = descriptor->shape.dims[1];

    const std::vector<size_t> grad_in_dims{rows, cols};
    const std::vector<size_t> gated_dims{rows, cols * 2};

    TensorWrapper grad_in_tensor(grad_in, grad_in_dims, descriptor->in_dtype);
    TensorWrapper fwd_input_tensor(fwd_input, gated_dims, descriptor->in_dtype);
    TensorWrapper grad_out_tensor(grad_out, gated_dims, descriptor->out_dtype);

    nvte_dgeglu(grad_in_tensor.data(), fwd_input_tensor.data(), grad_out_tensor.data(), stream);
}

// Custom-call entry point for the fused gated-GELU backward + cast + transpose
// (nvte_dgeglu_cast_transpose).
//   buffers[0]: incoming gradient, shaped as the full descriptor shape
//   buffers[1]: forward gated-GELU input, described as (m, 2 * n)
//   buffers[2..4]: amax, scale, scale_inv (float)
//   buffers[5]: output, described as (m, 2 * n)
//   buffers[6]: transposed output, described as (2 * n, m)
//   buffers[7]: amax output, asserted to alias buffers[2]
// When the descriptor's output dtype is not FP8, no amax/scale/scale_inv pointers are
// attached to the output tensors.
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len) {
    void *grad_in = buffers[0];
    void *fwd_input = buffers[1];
    auto *amax = static_cast<float *>(buffers[2]);
    auto *scale = static_cast<float *>(buffers[3]);
    auto *scale_inv = static_cast<float *>(buffers[4]);
    void *result = buffers[5];
    void *result_trans = buffers[6];
    auto *amax_out = static_cast<float *>(buffers[7]);

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    assert(amax == amax_out);

    const bool fp8_output = use_fp8(descriptor->out_dtype);
    float *out_amax = fp8_output ? amax_out : nullptr;
    float *out_scale = fp8_output ? scale : nullptr;
    float *out_scale_inv = fp8_output ? scale_inv : nullptr;

    const auto rows = descriptor->shape.dims[0];
    const auto cols = descriptor->shape.dims[1];
    const auto grad_in_dims = descriptor->shape.to_vector();
    const std::vector<size_t> gated_dims{rows, cols * 2};
    const std::vector<size_t> gated_trans_dims{cols * 2, rows};

    TensorWrapper grad_in_tensor(grad_in, grad_in_dims, descriptor->in_dtype);
    TensorWrapper fwd_input_tensor(fwd_input, gated_dims, descriptor->in_dtype);
    TensorWrapper result_tensor(result, gated_dims, descriptor->out_dtype, out_amax, out_scale,
                                out_scale_inv);
    TensorWrapper result_trans_tensor(result_trans, gated_trans_dims, descriptor->out_dtype,
                                      out_amax, out_scale, out_scale_inv);

    nvte_dgeglu_cast_transpose(grad_in_tensor.data(), fwd_input_tensor.data(),
                               result_tensor.data(), result_trans_tensor.data(), stream);
}

// Custom-call entry point for a GEMM executed through nvte_cublas_gemm.
//   buffers[0], buffers[1]: operands A and B
//   buffers[2], buffers[3]: float scale-inverse pointers attached to A and B
//   buffers[4]:             output D
// `opaque` must hold a CustomCallGemmDescriptor (m, n, k, operand dtypes, transpose
// flags, split-accumulator flag); UnpackOpaque throws std::runtime_error on a size
// mismatch. A workspace of kCublasLtForwardWorkspaceSize bytes is obtained from
// WorkspaceManager for the call.
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *A = buffers[0];
    auto *B = buffers[1];
    auto *A_scale_inverse = reinterpret_cast<float *>(buffers[2]);
    auto *B_scale_inverse = reinterpret_cast<float *>(buffers[3]);
    auto *D = buffers[4];

    // The shapes of A, B and D are transposed here so that the column-major
    // cuBLASLt GEMM is invoked correctly for row-major data.
    const auto &desc = *UnpackOpaque<CustomCallGemmDescriptor>(opaque, opaque_len);

    auto m = desc.m;
    auto n = desc.n;
    auto k = desc.k;
    auto A_shape = std::vector<size_t>{k, m};
    auto A_tensor = TensorWrapper(A, A_shape, desc.A_dtype, nullptr, nullptr, A_scale_inverse);

    auto B_shape = std::vector<size_t>{n, k};
    auto B_tensor = TensorWrapper(B, B_shape, desc.B_dtype, nullptr, nullptr, B_scale_inverse);

    auto D_shape = std::vector<size_t>{n, m};
    auto D_tensor = TensorWrapper(D, D_shape, desc.D_dtype);

    // Placeholder tensor with a null data pointer, passed for the two tensor arguments
    // this call does not supply.
    auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32);

    size_t workspace_size = kCublasLtForwardWorkspaceSize;
    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
    auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);

    // NOTE(review): the two null tensors and the literal `false` / `0` arguments are
    // positional; presumably bias, pre-GELU output, grad, accumulate and SM count —
    // confirm against transformer_engine/gemm.h.
    nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
                     null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
                     (desc.transb) ? CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false,
                     desc.use_split_accumulator, 0, stream);
}

285
286
287
288
void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
                          int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
                          void *bias, void *output, DType out_dtype, void *mu, void *rsigma,
                          float *amax, float *scale, float *scale_inv, cudaStream_t stream) {
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
    auto input_shape = std::vector<size_t>{n, hidden};
    auto weight_shape = std::vector<size_t>{hidden};
    auto intermediates_shape = std::vector<size_t>{n};
    auto is_layer_norm = (bias) ? true : false;

    auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);

    // assume output dtype = input dtype
    // If we need mixed I/O precision in the future, we need an additional
    // parameter for output type
    auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);

    // Create uninitialized workspace, barrier and init them on the first
    TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
305
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
306

307
308
309
310
311
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
    if (!is_layer_norm) {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
    }

312
313
314
315
316
    // The first call is to query the required workspace
    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);

317
        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
318
319
320
321
322
323
324
325
326
327
328
329
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
                           num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
    } else {
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), stream, num_sm, dummy_workspace_tensor.data(),
                         dummy_barrier_tensor.data());
    }

    size_t workspace_size =
        dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) +
        dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());

330
    void *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
331
332
333
334
335
336
337
338
339
340
341
342

    auto workspace_tensor =
        TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());

    auto barrier_tensor =
        TensorWrapper(reinterpret_cast<char *>(workspace) + dummy_workspace_tensor.shape().data[0],
                      dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());

    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);

343
        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
344
345
346
347
348
349
350
351
352
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
                           num_sm, workspace_tensor.data(), barrier_tensor.data());
    } else {
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

353
void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
354
355
356
                           int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
                           void *ograd, void *mu, void *rsigma, void *xgrad, void *wgrad,
                           void *dbeta, cudaStream_t stream) {
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    auto input_shape = std::vector<size_t>{n, hidden};
    auto weight_shape = std::vector<size_t>{hidden};
    auto intermediates_shape = std::vector<size_t>{n};
    auto intermediates_dtype = DType::kFloat32;
    auto is_layer_norm = (dbeta) ? true : false;

    // assume input type = output type
    auto *grad_output = ograd;
    auto x_dtype = in_dtype;
    auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);

    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);

    auto *x = input;
    auto x_tensor = TensorWrapper(x, input_shape, x_dtype);

    auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
    auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
    auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);

    TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
    TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
379
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
380
381
    size_t dbeta_part_size{};

382
383
384
385
386
    auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
    if (!is_layer_norm) {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
    }

387
388
389
390
391
    // The first call is to query the workspace
    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);

392
        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(),
                           dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream,
                           num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());

        dbeta_part_size = dummy_dbeta_part_tensor.shape().data[0] *
                          dummy_dbeta_part_tensor.shape().data[1] *
                          typeToSize(dummy_dbeta_part_tensor.dtype());
    } else {
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dummy_dgamma_part_tensor.data(), stream, num_sm,
                         dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
    }

    size_t workspace_size =
        dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype());
    size_t barrier_size =
        dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
    size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] *
                              dummy_dgamma_part_tensor.shape().data[1] *
                              typeToSize(dummy_dgamma_part_tensor.dtype());

416
417
    auto [workspace, dgamma_part, dbeta_part, barrier] = WorkspaceManager::Instance().GetWorkspace(
        workspace_size, dgamma_part_size, dbeta_part_size, barrier_size);
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433

    auto workspace_tensor =
        TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());

    auto barrier_tensor =
        TensorWrapper(barrier, dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());

    auto dgamma_part_tensor = TensorWrapper(dgamma_part, dummy_dgamma_part_tensor.shape(),
                                            dummy_dgamma_part_tensor.dtype());

    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
        auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(),
                                               dummy_dbeta_part_tensor.dtype());

434
        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
                           dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                           barrier_tensor.data());
    } else {
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

// Custom-call entry point for the LayerNorm forward pass with an FP8 (kFloat8E4M3) output.
//   buffers[0..2]: input, weight, bias
//   buffers[3..5]: amax, scale, scale_inv (float)
//   buffers[6]:    output
//   buffers[7..8]: mu, rsigma
//   buffers[9]:    amax output, asserted to alias buffers[3]
// `opaque` must hold a CustomCallNormDescriptor; UnpackOpaque throws
// std::runtime_error on a size mismatch.
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *amax = reinterpret_cast<float *>(buffers[3]);
    auto *scale = reinterpret_cast<float *>(buffers[4]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
    auto *output = buffers[6];
    auto *mu = buffers[7];
    auto *rsigma = buffers[8];
    auto *amax_out = buffers[9];
    assert(amax_out == amax);

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto n = desc.n;
    auto hidden = desc.hidden;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    // The output dtype is fixed for this entry point rather than read from the descriptor.
    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
                         w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

// Custom-call entry point for the LayerNorm forward pass without scaling metadata.
//   buffers[0..2]: input, weight, bias
//   buffers[3]:    output
//   buffers[4..5]: mu, rsigma
// The output dtype equals the descriptor's x_dtype, and amax/scale/scale_inv are
// passed as null. `opaque` must hold a CustomCallNormDescriptor; UnpackOpaque throws
// std::runtime_error on a size mismatch.
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *output = buffers[3];
    auto *mu = buffers[4];
    auto *rsigma = buffers[5];

    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto n = desc.n;
    auto hidden = desc.hidden;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto eps = desc.eps;
    auto out_dtype = in_dtype;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
                         w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

// Custom-call entry point for the LayerNorm backward pass.
//   buffers[0]:    ograd (incoming gradient)
//   buffers[1..2]: mu, rsigma
//   buffers[3..4]: input, weight
//   buffers[5..7]: xgrad, wgrad, dbeta (outputs)
// `opaque` must hold a CustomCallNormDescriptor; UnpackOpaque throws
// std::runtime_error on a size mismatch.
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);

    auto n = desc.n;
    auto hidden = desc.hidden;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    auto *ograd = buffers[0];
    auto *mu = buffers[1];
    auto *rsigma = buffers[2];
    auto *input = buffers[3];
    auto *weight = buffers[4];
    auto *xgrad = buffers[5];
    auto *wgrad = buffers[6];
    auto *dbeta = buffers[7];

    LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
                          w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
}

// Custom-call entry point for the RMSNorm forward pass with an FP8 (kFloat8E4M3) output.
//   buffers[0..1]: input, weight
//   buffers[2..4]: amax, scale, scale_inv (float)
//   buffers[5]:    output
//   buffers[6]:    rsigma
//   buffers[7]:    amax output, asserted to alias buffers[2]
// `bias` and `mu` are passed as null, which makes LayerNormForwardImpl take its RMSNorm
// path. `opaque` must hold a CustomCallNormDescriptor; UnpackOpaque throws
// std::runtime_error on a size mismatch.
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *amax = reinterpret_cast<float *>(buffers[2]);
    auto *scale = reinterpret_cast<float *>(buffers[3]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *rsigma = buffers[6];
    auto *amax_out = buffers[7];
    assert(amax_out == amax);

    void *bias = nullptr;
    void *mu = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto n = desc.n;
    auto hidden = desc.hidden;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;
    // The output dtype is fixed for this entry point rather than read from the descriptor.
    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
                         w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

// Custom-call entry point for the RMSNorm forward pass without scaling metadata.
//   buffers[0..1]: input, weight
//   buffers[2]:    output
//   buffers[3]:    rsigma
// `bias` and `mu` are passed as null, which makes LayerNormForwardImpl take its RMSNorm
// path; amax/scale/scale_inv are null and the output dtype equals the descriptor's
// x_dtype. `opaque` must hold a CustomCallNormDescriptor; UnpackOpaque throws
// std::runtime_error on a size mismatch.
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *output = buffers[2];
    auto *rsigma = buffers[3];

    void *bias = nullptr;
    void *mu = nullptr;
    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto n = desc.n;
    auto hidden = desc.hidden;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;
    auto out_dtype = in_dtype;

    LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
                         w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

// Custom-call entry point for the RMSNorm backward pass.
//   buffers[0]:    ograd (incoming gradient)
//   buffers[1]:    rsigma
//   buffers[2..3]: input, weight
//   buffers[4..5]: xgrad, wgrad (outputs)
// `mu` and `dbeta` are passed as null, which makes LayerNormBackwardImpl take its
// RMSNorm path. `opaque` must hold a CustomCallNormDescriptor; UnpackOpaque throws
// std::runtime_error on a size mismatch.
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *ograd = buffers[0];
    auto *rsigma = buffers[1];
    auto *input = buffers[2];
    auto *weight = buffers[3];
    auto *xgrad = buffers[4];
    auto *wgrad = buffers[5];

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto n = desc.n;
    auto hidden = desc.hidden;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    void *mu = nullptr;
    void *dbeta = nullptr;

    LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
                          w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
}

// Custom-call entry point for nvte_fp8_quantize.
//   buffers[0]:    input
//   buffers[1..3]: amax, scale, scale_inv (float)
//   buffers[4]:    output
//   buffers[5]:    amax output, asserted to alias buffers[1]
// Input and output share the descriptor's shape; the scaling pointers are attached to
// the output tensor.
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    auto *amax = static_cast<float *>(buffers[1]);
    auto *scale = static_cast<float *>(buffers[2]);
    auto *scale_inv = static_cast<float *>(buffers[3]);
    void *dst = buffers[4];
    auto *amax_out = static_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const auto dims = descriptor->shape.to_vector();

    TensorWrapper src_tensor(src, dims, descriptor->in_dtype);
    TensorWrapper dst_tensor(dst, dims, descriptor->out_dtype, amax_out, scale, scale_inv);

    nvte_fp8_quantize(src_tensor.data(), dst_tensor.data(), stream);
}

// Custom-call entry point for nvte_fp8_dequantize.
//   buffers[0]:    input
//   buffers[1..3]: amax, scale, scale_inv (float)
//   buffers[4]:    output
// Input and output share the descriptor's shape; the scaling pointers are attached to
// the input tensor.
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    auto *amax = static_cast<float *>(buffers[1]);
    auto *scale = static_cast<float *>(buffers[2]);
    auto *scale_inv = static_cast<float *>(buffers[3]);
    void *dst = buffers[4];

    const auto *descriptor = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const auto dims = descriptor->shape.to_vector();

    TensorWrapper src_tensor(src, dims, descriptor->in_dtype, amax, scale, scale_inv);
    TensorWrapper dst_tensor(dst, dims, descriptor->out_dtype);

    nvte_fp8_dequantize(src_tensor.data(), dst_tensor.data(), stream);
}

// Custom-call entry point for nvte_scaled_softmax_forward.
//   buffers[0]: input, buffers[1]: output
// Both tensors are described as (batch, heads, q_seqlen, k_seqlen) with the dtype and
// scale factor taken from the SoftmaxDescriptor in `opaque`.
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
    void *src = buffers[0];
    void *dst = buffers[1];

    const auto *descriptor = UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const std::vector<size_t> dims{descriptor->batch, descriptor->heads, descriptor->q_seqlen,
                                   descriptor->k_seqlen};
    const auto elem_type = descriptor->dtype;

    TensorWrapper src_tensor(src, dims, elem_type);
    TensorWrapper dst_tensor(dst, dims, elem_type);

    nvte_scaled_softmax_forward(src_tensor.data(), dst_tensor.data(), descriptor->scale_factor,
                                stream);
}

// Custom-call entry point for nvte_scaled_softmax_backward.
//   buffers[0]: incoming gradient
//   buffers[1]: softmax output from the forward pass
//   buffers[2]: output gradient
// All three tensors are described as (batch, heads, q_seqlen, k_seqlen) with the dtype
// and scale factor taken from the SoftmaxDescriptor in `opaque`.
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    void *incoming_grad = buffers[0];
    void *forward_out = buffers[1];
    void *outgoing_grad = buffers[2];

    const auto *descriptor = UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const std::vector<size_t> dims{descriptor->batch, descriptor->heads, descriptor->q_seqlen,
                                   descriptor->k_seqlen};
    const auto elem_type = descriptor->dtype;

    TensorWrapper incoming_grad_tensor(incoming_grad, dims, elem_type);
    TensorWrapper forward_out_tensor(forward_out, dims, elem_type);
    TensorWrapper outgoing_grad_tensor(outgoing_grad, dims, elem_type);

    nvte_scaled_softmax_backward(incoming_grad_tensor.data(), forward_out_tensor.data(),
                                 outgoing_grad_tensor.data(), descriptor->scale_factor, stream);
}

// Scaled softmax with an explicit mask, forward pass.
//
// Buffer layout: buffers[0] = input, buffers[1] = mask, buffers[2] = output.
// Input/output are wrapped as {batch, heads, q_seqlen, k_seqlen}; the mask is
// wrapped as {pad_batch, 1, q_seqlen, k_seqlen}.
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                size_t opaque_len) {
    const auto &softmax_desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);

    const std::vector<size_t> data_dims = {softmax_desc.batch, softmax_desc.heads,
                                           softmax_desc.q_seqlen, softmax_desc.k_seqlen};
    const std::vector<size_t> mask_dims = {softmax_desc.pad_batch, 1, softmax_desc.q_seqlen,
                                           softmax_desc.k_seqlen};

    TensorWrapper src(buffers[0], data_dims, softmax_desc.dtype);
    // The mask buffer is always wrapped as a byte tensor, independent of the data dtype.
    TensorWrapper mask_bytes(buffers[1], mask_dims, DType::kByte);
    TensorWrapper dst(buffers[2], data_dims, softmax_desc.dtype);

    nvte_scaled_masked_softmax_forward(src.data(), mask_bytes.data(), dst.data(),
                                       softmax_desc.scale_factor, stream);
}

// Masked scaled softmax, backward pass.
//
// No dedicated implementation: all arguments are forwarded unchanged to
// ScaledSoftmaxBackward, so the buffer layout is the one that function expects.
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 size_t opaque_len) {
    return ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}

// Scaled softmax with an upper-triangular mask, forward pass.
//
// Buffer layout: buffers[0] = input, buffers[1] = output.
// The batch and head dimensions of the descriptor are folded into one leading
// dimension, so both buffers are wrapped as {batch * heads, q_seqlen, k_seqlen}.
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           size_t opaque_len) {
    const auto &softmax_desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);

    const std::vector<size_t> dims = {softmax_desc.batch * softmax_desc.heads,
                                      softmax_desc.q_seqlen, softmax_desc.k_seqlen};

    TensorWrapper src(buffers[0], dims, softmax_desc.dtype);
    TensorWrapper dst(buffers[1], dims, softmax_desc.dtype);

    nvte_scaled_upper_triang_masked_softmax_forward(src.data(), dst.data(),
                                                    softmax_desc.scale_factor, stream);
}

// Scaled softmax with an upper-triangular mask, backward pass.
//
// Buffer layout: buffers[0] = incoming gradient, buffers[1] = softmax output
// saved from the forward pass, buffers[2] = gradient to write.
// All three are wrapped as {batch * heads, q_seqlen, k_seqlen}.
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            size_t opaque_len) {
    const auto &softmax_desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);

    const std::vector<size_t> dims = {softmax_desc.batch * softmax_desc.heads,
                                      softmax_desc.q_seqlen, softmax_desc.k_seqlen};

    TensorWrapper incoming_grad(buffers[0], dims, softmax_desc.dtype);
    TensorWrapper forward_out(buffers[1], dims, softmax_desc.dtype);
    TensorWrapper outgoing_grad(buffers[2], dims, softmax_desc.dtype);

    nvte_scaled_upper_triang_masked_softmax_backward(incoming_grad.data(), forward_out.data(),
                                                     outgoing_grad.data(),
                                                     softmax_desc.scale_factor, stream);
}
// Thin wrapper over nvte_get_fused_attn_backend.
//
// Converts the two DType arguments to NVTEDType and forwards every other
// argument unchanged, returning whatever backend the library reports.
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim) {
    const auto nvte_q_dtype = static_cast<NVTEDType>(q_dtype);
    const auto nvte_kv_dtype = static_cast<NVTEDType>(kv_dtype);
    return nvte_get_fused_attn_backend(nvte_q_dtype, nvte_kv_dtype, qkv_layout, bias_type,
                                       mask_type, dropout_probability, q_max_seqlen,
                                       kv_max_seqlen, head_dim);
}

// Fused self-attention, forward pass, with Q/K/V packed in a single buffer.
//
// `opaque` must hold a packed CustomCallFusedAttnDescriptor.
//
// Buffer layout:
//   inputs : buffers[0] = qkv, buffers[1] = bias, buffers[2] = cu_seqlens,
//            buffers[3] = seed
//   outputs: buffers[4] = output, buffers[5] = softmax_aux, buffers[6] = rng_state
//
// Fails the NVTE_CHECK below when the descriptor's q_max_seqlen and
// kv_max_seqlen differ.
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input
    void *qkv = buffers[0];
    void *bias = buffers[1];
    void *cu_seqlens = buffers[2];
    void *seed = buffers[3];

    // output
    void *output = buffers[4];
    void *softmax_aux = buffers[5];
    void *rng_state = buffers[6];

    auto batch = descriptor.batch;
    auto num_head = descriptor.num_head;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto head_dim = descriptor.head_dim;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

    NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
               "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");

    auto dtype = descriptor.dtype;
    auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
    auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

    // input tensors
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
    auto cu_seqlens_tensor =
        TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

    // output tensors
    auto o_tensor =
        TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);

    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    // aux tensors
    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);

    // The backend is looked up only so it can be passed to PopulateRngStateAsync,
    // which receives the rng_state output buffer and the seed input.
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    TensorWrapper query_workspace_tensor;

    // First call: made with a default-constructed workspace tensor. Its shape and
    // dtype are read back below, so this call presumably only reports the
    // workspace requirement (NOTE(review): confirm against the fused_attn API).
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                  rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
                                  descriptor.scaling_factor, dropout_probability, qkv_layout,
                                  bias_type, mask_type, query_workspace_tensor.data(), stream);

    // Point the first aux output tensor at the caller-provided softmax_aux buffer.
    auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
    output_s->data.dptr = softmax_aux;

    auto workspace_size = query_workspace_tensor.shape().data[0];
    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
    auto workspace_tensor =
        TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

    // Second call: same arguments, now with the real workspace attached.
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                  rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
                                  descriptor.scaling_factor, dropout_probability, qkv_layout,
                                  bias_type, mask_type, workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

// Fused self-attention, backward pass, with Q/K/V packed in a single buffer.
//
// `opaque` must hold a packed CustomCallFusedAttnDescriptor.
//
// Buffer layout:
//   inputs : buffers[0] = qkv, buffers[1] = bias, buffers[2] = softmax_aux,
//            buffers[3] = rng_state, buffers[4] = output, buffers[5] = doutput,
//            buffers[6] = cu_seqlens
//   outputs: buffers[7] = dqkv, buffers[8] = dbias
//
// Fails the NVTE_CHECK below when the descriptor's q_max_seqlen and
// kv_max_seqlen differ.
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input
    void *qkv = buffers[0];
    void *bias = buffers[1];
    void *softmax_aux = buffers[2];
    void *rng_state = buffers[3];
    void *output = buffers[4];
    void *doutput = buffers[5];
    void *cu_seqlens = buffers[6];

    // output
    void *dqkv = buffers[7];
    void *dbias = buffers[8];

    auto batch = descriptor.batch;
    auto num_head = descriptor.num_head;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto head_dim = descriptor.head_dim;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

    NVTE_CHECK(q_max_seqlen == kv_max_seqlen,
               "q_max_seqlen should be equal to kv_max_seqlen in the self attention.");

    auto dtype = descriptor.dtype;
    auto qkv_shape = std::vector<size_t>{batch * q_max_seqlen, 3, num_head, head_dim};
    auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
    auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);

    auto cu_seqlens_tensor =
        TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

    // TODO: needs to think about how to pass aux_output_tensors
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    // The pack is filled by hand with three entries:
    // [0] softmax_aux, [1] rng_state, [2] bias.
    aux_output_tensors.size = 3;
    auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
    output_s->data.dptr = softmax_aux;
    auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
    rng_state_tensor->data.shape = std::vector<size_t>{2};
    rng_state_tensor->data.dtype = DType::kInt64;
    rng_state_tensor->data.dptr = rng_state;
    auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
    bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);

    TensorWrapper query_workspace_tensor;

    // First call: made with a default-constructed workspace tensor. Its shape and
    // dtype are read back below, so this call presumably only reports the
    // workspace requirement (NOTE(review): confirm against the fused_attn API).
    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                  s_tensor.data(),  // not used for F16
                                  s_tensor.data(),  // not used for F16
                                  &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                  cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type,
                                  query_workspace_tensor.data(), stream);

    size_t workspace_size = query_workspace_tensor.shape().data[0];
    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
    auto workspace_tensor =
        TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

    // Second call: same arguments, now with the real workspace attached.
    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                  s_tensor.data(),  // not used for F16
                                  s_tensor.data(),  // not used for F16
                                  &aux_output_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                  cu_seqlens_tensor.data(), q_max_seqlen, descriptor.scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type,
                                  workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

// Fused cross-attention, forward pass, with Q separate and K/V packed together.
//
// `opaque` must hold a packed CustomCallFusedAttnDescriptor.
//
// Buffer layout:
//   inputs : buffers[0] = q, buffers[1] = kv, buffers[2] = bias,
//            buffers[3] = q_cu_seqlens, buffers[4] = kv_cu_seqlens,
//            buffers[5] = seed
//   outputs: buffers[6] = output, buffers[7] = softmax_aux, buffers[8] = rng_state
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input
    void *q = buffers[0];
    void *kv = buffers[1];
    void *bias = buffers[2];
    void *q_cu_seqlens = buffers[3];
    void *kv_cu_seqlens = buffers[4];
    void *seed = buffers[5];

    // output
    void *output = buffers[6];
    void *softmax_aux = buffers[7];
    void *rng_state = buffers[8];

    auto batch = descriptor.batch;
    auto num_head = descriptor.num_head;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto head_dim = descriptor.head_dim;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;

    auto dtype = descriptor.dtype;
    auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
    auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
    auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

    // input tensors
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);

    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

    // output tensors
    auto o_tensor =
        TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);

    // aux tensors

    // F16 doesn't use s_tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);

    // The backend is looked up only so it can be passed to PopulateRngStateAsync,
    // which receives the rng_state output buffer and the seed input.
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    TensorWrapper query_workspace_tensor;

    // First call: made with a default-constructed workspace tensor. Its shape and
    // dtype are read back below, so this call presumably only reports the
    // workspace requirement (NOTE(review): confirm against the fused_attn API).
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
        descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        query_workspace_tensor.data(), stream);

    // Point the first aux output tensor at the caller-provided softmax_aux buffer.
    auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
    output_s->data.dptr = softmax_aux;

    auto workspace_size = query_workspace_tensor.shape().data[0];
    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
    auto workspace_tensor =
        TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

    // Second call: same arguments, now with the real workspace attached.
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
        descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

// Fused cross-attention, backward pass, with Q separate and K/V packed together.
//
// `opaque` must hold a packed CustomCallFusedAttnDescriptor.
//
// Buffer layout:
//   inputs : buffers[0] = q, buffers[1] = kv, buffers[2] = bias,
//            buffers[3] = softmax_aux, buffers[4] = rng_state,
//            buffers[5] = output, buffers[6] = doutput,
//            buffers[7] = q_cu_seqlens, buffers[8] = kv_cu_seqlens
//   outputs: buffers[9] = dq, buffers[10] = dkv, buffers[11] = dbias
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input
    void *q = buffers[0];
    void *kv = buffers[1];
    void *bias = buffers[2];
    void *softmax_aux = buffers[3];
    void *rng_state = buffers[4];
    void *output = buffers[5];
    void *doutput = buffers[6];
    void *q_cu_seqlens = buffers[7];
    void *kv_cu_seqlens = buffers[8];

    // output
    void *dq = buffers[9];
    void *dkv = buffers[10];
    void *dbias = buffers[11];

    auto batch = descriptor.batch;
    auto num_head = descriptor.num_head;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto head_dim = descriptor.head_dim;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;

    auto dtype = descriptor.dtype;
    auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
    auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
    auto output_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
    auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

    // TODO(rewang): need to think about how to pass aux_output_tensors
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    // The pack is filled by hand with three entries:
    // [0] softmax_aux, [1] rng_state, [2] bias.
    aux_output_tensors.size = 3;
    auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
    output_s->data.dptr = softmax_aux;
    auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
    rng_state_tensor->data.shape = std::vector<size_t>{2};
    rng_state_tensor->data.dtype = DType::kInt64;
    rng_state_tensor->data.dptr = rng_state;
    auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
    bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);

    TensorWrapper query_workspace_tensor;

    // First call: made with a default-constructed workspace tensor. Its shape and
    // dtype are read back below, so this call presumably only reports the
    // workspace requirement (NOTE(review): confirm against the fused_attn API).
    nvte_fused_attn_bwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for FP16/BF16
        s_tensor.data(),  // not used for FP16/BF16
        &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
        descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        query_workspace_tensor.data(), stream);

    size_t workspace_size = query_workspace_tensor.shape().data[0];
    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);

    auto workspace_tensor =
        TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

    // Second call: same arguments, now with the real workspace attached.
    nvte_fused_attn_bwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for FP16/BF16
        s_tensor.data(),  // not used for FP16/BF16
        &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
        descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

}  // namespace jax
}  // namespace transformer_engine