modules.cpp 71 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
9
10
11
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "jax/csrc/modules.h"

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
12
#include <cudnn.h>
13
14
15
16
17
18

#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
19
#include "common/util/logging.h"
20
21
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
22
#include "transformer_engine/fused_attn.h"
23
24
25
26
27
28
29
30
31
32
33
34
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

// Returns true when `type` is one of the two FP8 encodings (E4M3 or E5M2).
inline bool use_fp8(DType type) {
    switch (type) {
        case DType::kFloat8E4M3:
        case DType::kFloat8E5M2:
            return true;
        default:
            return false;
    }
}

35
36
37
38
// Copies the `ndim` extents referenced by an NVTEShape into an owning std::vector.
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
    std::vector<size_t> dims;
    dims.assign(shape.data, shape.data + shape.ndim);
    return dims;
}

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Serializes `descriptor` by copying its raw object bytes (sizeof(T) of them) into a Python
// bytes object. NOTE(review): only meaningful if T is trivially copyable -- confirm for new
// descriptor types.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
    const char *raw_bytes = reinterpret_cast<const char *>(&descriptor);
    return pybind11::bytes(raw_bytes, sizeof(T));
}

// Reinterprets an opaque byte blob as a pointer to T without copying.
// Throws std::runtime_error("Invalid opaque object size") unless the blob is exactly sizeof(T)
// bytes long. The returned pointer aliases `opaque` and is only valid while that buffer lives.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
    if (opaque_len == sizeof(T)) {
        return reinterpret_cast<const T *>(opaque);
    }
    throw std::runtime_error("Invalid opaque object size");
}

// Builds a CustomCallCommonDescriptor (operand shape plus input/output dtypes) and packs it into
// opaque bytes; handlers such as Transpose decode it with
// UnpackOpaque<CustomCallCommonDescriptor>.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype) {
    CustomCallCommonDescriptor descriptor;
    descriptor.in_dtype = in_dtype;
    descriptor.out_dtype = out_dtype;
    descriptor.shape.from_vector(shape);
    return PackOpaque(descriptor);
}

62
63
64
65
66
67
68
69
70
71
72
73
// Like PackCustomCallCommonDescriptor, but also records the shape and dtype of a scratch
// workspace buffer (decoded e.g. by DGeluDBiasCastTranspose).
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype) {
    CustomCallCommonWkDescriptor descriptor;
    // dtypes of the input, output and workspace buffers
    descriptor.in_dtype = in_dtype;
    descriptor.out_dtype = out_dtype;
    descriptor.wk_dtype = wk_dtype;
    // operand shape and workspace shape
    descriptor.shape.from_vector(shape);
    descriptor.wkshape.from_vector(wkshape);
    return PackOpaque(descriptor);
}

74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Packs every parameter a norm (LayerNorm / RMSNorm) custom call needs -- problem sizes,
// workspace/barrier sizes, partial-gradient buffer shapes, buffer dtypes and numeric options --
// into opaque bytes decoded with UnpackOpaque<CustomCallNormDescriptor>.
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
    CustomCallNormDescriptor norm_desc;

    // Problem sizes and scratch-buffer extents.
    norm_desc.batch_size = batch_size;
    norm_desc.hidden_size = hidden_size;
    norm_desc.wkspace_size = wkspace_size;
    norm_desc.barrier_size = barrier_size;
    norm_desc.dgamma_part_shape.from_vector(dgamma_part_shape);
    norm_desc.dbeta_part_shape.from_vector(dbeta_part_shape);

    // Buffer dtypes.
    norm_desc.x_dtype = x_dtype;
    norm_desc.w_dtype = w_dtype;
    norm_desc.wkspace_dtype = wkspace_dtype;
    norm_desc.barrier_dtype = barrier_dtype;
    norm_desc.dgamma_part_dtype = dgamma_part_dtype;
    norm_desc.dbeta_part_dtype = dbeta_part_dtype;

    // Numeric / launch options.
    norm_desc.zero_centered_gamma = zero_centered_gamma;
    norm_desc.eps = eps;
    norm_desc.sm_margin = sm_margin;

    return PackOpaque(norm_desc);
}

98
99
100
101
102
// Packs the softmax problem description into opaque bytes. The brace initializer relies on
// SoftmaxDescriptor's declaration order matching the argument order below.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor) {
    const SoftmaxDescriptor descriptor{batch_size, padding_size, head_dim, q_seqlen,
                                       k_seqlen,   dtype,        scale_factor};
    return PackOpaque(descriptor);
}

105
// Packs the fused-attention configuration (batch/sequence/head sizes, workspace, scaling and
// dropout, bias/mask/layout enums, dtypes and the training flag) into opaque bytes. The brace
// initializer relies on CustomCallFusedAttnDescriptor's declaration order matching this order.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training) {
    const CustomCallFusedAttnDescriptor descriptor{
        input_batch,     bias_batch,          q_max_seqlen, kv_max_seqlen, attn_heads,
        num_gqa_groups,  bias_heads,          head_dim,     wkspace_size,  scaling_factor,
        dropout_probability, bias_type,       mask_type,    qkv_layout,    dtype,
        wkspace_dtype,   is_training};
    return PackOpaque(descriptor);
}

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
    auto input_shape = std::vector<size_t>{rows, cols};
    auto output_shape = std::vector<size_t>{cols, rows};

    auto input_tensor = TensorWrapper(input, input_shape, dtype);
    auto transposed_tensor = TensorWrapper(output, output_shape, dtype);

    nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}

// Custom-call entry for a 2-D transpose.
// buffers: [0] input, [1] output. The descriptor's in/out dtypes are expected to agree
// (asserted in debug builds); the out dtype is used for both tensors.
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallCommonDescriptor *desc =
        UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    assert(desc->in_dtype == desc->out_dtype);

    TransposeImpl(buffers[0], desc->shape.dims[0], desc->shape.dims[1], desc->out_dtype, stream,
                  buffers[1]);
}

// Custom-call entry that casts an (m, n) input to desc.out_dtype and also writes the (n, m)
// transpose of the cast result.
// buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] cast output,
//          [5] transposed cast output, [6] amax output (asserted to alias [1]).
// amax/scale/scale_inv are forwarded only when the output dtype is FP8.
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    void *dst = buffers[4];
    void *dst_t = buffers[5];
    assert(reinterpret_cast<float *>(buffers[1]) == reinterpret_cast<float *>(buffers[6]));

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const bool fp8_out = use_fp8(desc.out_dtype);
    float *scale = fp8_out ? reinterpret_cast<float *>(buffers[2]) : nullptr;
    float *scale_inv = fp8_out ? reinterpret_cast<float *>(buffers[3]) : nullptr;
    float *amax_out = fp8_out ? reinterpret_cast<float *>(buffers[6]) : nullptr;

    const auto rows = desc.shape.dims[0];
    const auto cols = desc.shape.dims[1];
    const std::vector<size_t> shape{rows, cols};
    const std::vector<size_t> shape_t{cols, rows};

    TensorWrapper src_tensor(src, shape, desc.in_dtype);
    TensorWrapper dst_tensor(dst, shape, desc.out_dtype, amax_out, scale, scale_inv);
    TensorWrapper dst_t_tensor(dst_t, shape_t, desc.out_dtype, amax_out, scale, scale_inv);

    nvte_cast_transpose(src_tensor.data(), dst_tensor.data(), dst_t_tensor.data(), stream);
}

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
              cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
    auto input_shape = std::vector<size_t>{m, n};
    auto output_shape = std::vector<size_t>{m, n};

    auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));

    auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                       scale, scale_inverse);

    nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
}

// Custom-call entry for non-FP8 GeLU. buffers: [0] input, [1] output.
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto *desc = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

    // No amax/scale/scale_inv metadata on this path.
    GeluImpl(buffers[0], desc->shape.dims[0], desc->shape.dims[1], desc->in_dtype,
             desc->out_dtype, nullptr, stream, nullptr, nullptr, buffers[1]);
}

// Custom-call entry for GeLU with optional FP8 output.
// buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output,
//          [5] amax output (asserted to alias [1]).
// Scaling pointers are dropped (set to nullptr) unless desc.out_dtype is FP8.
void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    void *dst = buffers[4];
    float *amax_in = reinterpret_cast<float *>(buffers[1]);
    float *amax_out = reinterpret_cast<float *>(buffers[5]);
    assert(amax_in == amax_out);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    if (!use_fp8(desc.out_dtype)) {
        scale = scale_inv = amax_out = nullptr;
    }

    GeluImpl(src, desc.shape.dims[0], desc.shape.dims[1], desc.in_dtype, desc.out_dtype, scale,
             stream, scale_inv, amax_out, dst);
}

// Custom-call entry for the GeLU backward (nvte_dgelu) on (m, n) tensors.
// buffers: [0] input (presumably the upstream gradient -- confirm against the caller),
//          [1] the forward GeLU input, [2] output.
// [0] and [1] use desc.in_dtype; [2] uses desc.out_dtype.
void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const std::vector<size_t> shape{desc.shape.dims[0], desc.shape.dims[1]};

    TensorWrapper grad_tensor(buffers[0], shape, desc.in_dtype);
    TensorWrapper act_input_tensor(buffers[1], shape, desc.in_dtype);
    TensorWrapper dx_tensor(buffers[2], shape, desc.out_dtype);

    nvte_dgelu(grad_tensor.data(), act_input_tensor.data(), dx_tensor.data(), stream);
}

// Asks nvte_cast_transpose_dbias_dgelu how much workspace it needs for a
// (batch_size, hidden_size) problem.
// All data pointers (and the stream) are null; the kernel entry fills in the workspace
// wrapper's shape and dtype, which are returned as a one-element tuple
// ((workspace_shape, workspace_dtype),).
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype) {
    const std::vector<size_t> shape{batch_size, hidden_size};
    const std::vector<size_t> shape_t{hidden_size, batch_size};
    const std::vector<size_t> bias_shape{hidden_size};

    TensorWrapper grad_tensor(nullptr, shape, in_dtype);
    TensorWrapper act_input_tensor(nullptr, shape, in_dtype);
    TensorWrapper out_tensor(nullptr, shape, out_dtype);
    TensorWrapper out_t_tensor(nullptr, shape_t, out_dtype);
    TensorWrapper dbias_tensor(nullptr, bias_shape, in_dtype);
    TensorWrapper workspace_query;  // populated by the call below

    nvte_cast_transpose_dbias_dgelu(grad_tensor.data(), act_input_tensor.data(), out_tensor.data(),
                                    out_t_tensor.data(), dbias_tensor.data(),
                                    workspace_query.data(), nullptr);

    return pybind11::make_tuple(
        std::make_pair(MakeShapeVector(workspace_query.shape()), workspace_query.dtype()));
}

// Custom-call entry for the fused GeLU-backward + bias-gradient + cast + transpose kernel.
// buffers: [0] input, [1] forward GeLU input, [2] amax, [3] scale, [4] scale_inv,
//          [5] output (m, n), [6] transposed output (n, m), [7] dbias (n),
//          [8] amax output (asserted to alias [2]), [9] workspace.
// Workspace shape/dtype come from the descriptor. Scaling pointers are forwarded only for FP8
// outputs.
void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len) {
    void *grad = buffers[0];
    void *act_input = buffers[1];
    float *amax_in = reinterpret_cast<float *>(buffers[2]);
    float *scale = reinterpret_cast<float *>(buffers[3]);
    float *scale_inv = reinterpret_cast<float *>(buffers[4]);
    void *out = buffers[5];
    void *out_t = buffers[6];
    void *dbias = buffers[7];
    float *amax_out = reinterpret_cast<float *>(buffers[8]);
    void *workspace_ptr = buffers[9];

    const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
    assert(amax_in == amax_out);
    const bool fp8_out = use_fp8(desc.out_dtype);
    if (!fp8_out) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }

    const auto rows = desc.shape.dims[0];
    const auto cols = desc.shape.dims[1];
    const std::vector<size_t> shape{rows, cols};
    const std::vector<size_t> shape_t{cols, rows};
    const std::vector<size_t> bias_shape{cols};

    TensorWrapper grad_tensor(grad, shape, desc.in_dtype);
    TensorWrapper act_input_tensor(act_input, shape, desc.in_dtype);
    TensorWrapper out_tensor(out, shape, desc.out_dtype, amax_out, scale, scale_inv);
    TensorWrapper out_t_tensor(out_t, shape_t, desc.out_dtype, amax_out, scale, scale_inv);
    TensorWrapper dbias_tensor(dbias, bias_shape, desc.in_dtype);
    TensorWrapper workspace_tensor(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

    nvte_cast_transpose_dbias_dgelu(grad_tensor.data(), act_input_tensor.data(), out_tensor.data(),
                                    out_t_tensor.data(), dbias_tensor.data(),
                                    workspace_tensor.data(), stream);
}

304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
                   cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
    auto input_shape = std::vector<size_t>{m, n * 2};
    auto output_shape = std::vector<size_t>{m, n};

    auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));

    auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                       scale, scale_inverse);

    nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
}

// Custom-call entry for non-FP8 gated GeLU. buffers: [0] input (m, 2n), [1] output (m, n),
// where (m, n) = desc.shape.
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto *desc = UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

    GatedGeluImpl(buffers[0], desc->shape.dims[0], desc->shape.dims[1], desc->in_dtype,
                  desc->out_dtype, nullptr, stream, nullptr, nullptr, buffers[1]);
}

// Custom-call entry for gated GeLU with optional FP8 output.
// buffers: [0] input (m, 2n), [1] amax, [2] scale, [3] scale_inv, [4] output (m, n),
//          [5] amax output (asserted to alias [1]).
// Scaling pointers are forwarded only when desc.out_dtype is FP8.
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *src = buffers[0];
    void *dst = buffers[4];
    assert(reinterpret_cast<float *>(buffers[1]) == reinterpret_cast<float *>(buffers[5]));

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const bool fp8_out = use_fp8(desc.out_dtype);
    float *scale = fp8_out ? reinterpret_cast<float *>(buffers[2]) : nullptr;
    float *scale_inv = fp8_out ? reinterpret_cast<float *>(buffers[3]) : nullptr;
    float *amax_out = fp8_out ? reinterpret_cast<float *>(buffers[5]) : nullptr;

    GatedGeluImpl(src, desc.shape.dims[0], desc.shape.dims[1], desc.in_dtype, desc.out_dtype,
                  scale, stream, scale_inv, amax_out, dst);
}

// Custom-call entry for the gated-GeLU backward (nvte_dgeglu).
// buffers: [0] input (m, n), [1] forward gated input (m, 2n), [2] output (m, 2n),
// where (m, n) = desc.shape. [0] and [1] use desc.in_dtype; [2] uses desc.out_dtype.
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const auto rows = desc.shape.dims[0];
    const auto cols = desc.shape.dims[1];
    const std::vector<size_t> grad_shape{rows, cols};
    const std::vector<size_t> gated_shape{rows, cols * 2};

    TensorWrapper grad_tensor(buffers[0], grad_shape, desc.in_dtype);
    TensorWrapper act_input_tensor(buffers[1], gated_shape, desc.in_dtype);
    TensorWrapper dx_tensor(buffers[2], gated_shape, desc.out_dtype);

    nvte_dgeglu(grad_tensor.data(), act_input_tensor.data(), dx_tensor.data(), stream);
}

// Custom-call entry for the fused gated-GeLU backward + cast + transpose kernel.
// buffers: [0] input (desc.shape), [1] forward gated input (m, 2n), [2] amax, [3] scale,
//          [4] scale_inv, [5] output (m, 2n), [6] transposed output (2n, m),
//          [7] amax output (asserted to alias [2]).
// Scaling pointers are forwarded only when desc.out_dtype is FP8.
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len) {
    void *grad = buffers[0];
    void *act_input = buffers[1];
    float *amax_in = reinterpret_cast<float *>(buffers[2]);
    float *scale = reinterpret_cast<float *>(buffers[3]);
    float *scale_inv = reinterpret_cast<float *>(buffers[4]);
    void *out = buffers[5];
    void *out_t = buffers[6];
    float *amax_out = reinterpret_cast<float *>(buffers[7]);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    assert(amax_in == amax_out);
    if (!use_fp8(desc.out_dtype)) {
        scale = scale_inv = amax_out = nullptr;
    }

    const auto rows = desc.shape.dims[0];
    const auto gated_cols = desc.shape.dims[1] * 2;
    const std::vector<size_t> gated_shape{rows, gated_cols};
    const std::vector<size_t> gated_shape_t{gated_cols, rows};

    TensorWrapper grad_tensor(grad, desc.shape.to_vector(), desc.in_dtype);
    TensorWrapper act_input_tensor(act_input, gated_shape, desc.in_dtype);
    TensorWrapper out_tensor(out, gated_shape, desc.out_dtype, amax_out, scale, scale_inv);
    TensorWrapper out_t_tensor(out_t, gated_shape_t, desc.out_dtype, amax_out, scale, scale_inv);

    nvte_dgeglu_cast_transpose(grad_tensor.data(), act_input_tensor.data(), out_tensor.data(),
                               out_t_tensor.data(), stream);
}

406
407
408
409
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps) {
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};

    // empty tensor wrappers are okay just to get workspace size
    auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
    auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
    auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);

    // dummy tensor wrappers that will carry workspace size info later
    TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
427

428
429
430
431
432
433
434
435
436
        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
                           num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
                         dummy_barrier_tensor.data());
    }
437

438
439
440
441
    auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
    auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
                                std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
442
443
}

444
445
446
447
448
449
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
                          size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
                          DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
                          DType out_dtype, void *workspace, DType work_dtype, void *barrier,
                          DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
                          float *scale_inv, cudaStream_t stream) {
450
451
452
453
454
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto workspace_shape = std::vector<size_t>{workspace_size};
    auto barrier_shape = std::vector<size_t>{barrier_size};
455
456
457
458
459
460
461
462
463
464
465
    auto is_layer_norm = (bias) ? true : false;

    auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);

    // assume output dtype = input dtype
    // If we need mixed I/O precision in the future, we need an additional
    // parameter for output type
    auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);

466
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
467
468
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;

469
470
471
    auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
    auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);

472
473
474
475
    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);

476
        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
477
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
478
                           num_sm, workspace_tensor.data(), barrier_tensor.data());
479
    } else {
480
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
481
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
482
483
                         rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
484
    }
485
}
486

487
488
489
490
// Queries the scratch buffers the backward norm kernel needs: workspace, barrier, and the
// per-block partial dgamma / dbeta buffers.
// The backward entry point is called with null data pointers and a null stream; the dummy
// wrappers come back carrying shape and dtype.
// Returns a 4-tuple of (shape, dtype) pairs in the order
// (workspace, barrier, dgamma_part, dbeta_part).
// For RMSNorm (is_layer_norm == false) there is no beta: dbeta_part is reported with shape
// {0, 0}, and zero_centered_gamma is rejected via NVTE_CHECK.
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps) {
    const std::vector<size_t> act_shape{batch_size, hidden_size};
    const std::vector<size_t> param_shape{hidden_size};
    const std::vector<size_t> stats_shape{batch_size};
    const DType stats_dtype = DType::kFloat32;

    // Null data pointers: only the buffer requirements are wanted here.
    TensorWrapper dz_t(nullptr, act_shape, in_dtype);
    TensorWrapper rsigma_t(nullptr, stats_shape, stats_dtype);
    TensorWrapper x_t(nullptr, act_shape, in_dtype);
    TensorWrapper gamma_t(nullptr, param_shape, w_dtype);
    TensorWrapper dx_t(nullptr, act_shape, in_dtype);
    TensorWrapper dgamma_t(nullptr, param_shape, w_dtype);

    // Holders whose shape/dtype are filled in by the backward call.
    TensorWrapper workspace_q, barrier_q;
    TensorWrapper dgamma_part_q, dbeta_part_q;
    const auto sm_count = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    const auto bwd = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;

    std::vector<size_t> dbeta_part_shape;
    if (is_layer_norm) {
        TensorWrapper mu_t(nullptr, stats_shape, stats_dtype);
        TensorWrapper dbeta_t(nullptr, param_shape, w_dtype);

        bwd(dz_t.data(), x_t.data(), mu_t.data(), rsigma_t.data(), gamma_t.data(), dx_t.data(),
            dgamma_t.data(), dbeta_t.data(), dgamma_part_q.data(), dbeta_part_q.data(), nullptr,
            sm_count, workspace_q.data(), barrier_q.data());

        dbeta_part_shape = MakeShapeVector(dbeta_part_q.shape());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_bwd(dz_t.data(), x_t.data(), rsigma_t.data(), gamma_t.data(), dx_t.data(),
                         dgamma_t.data(), dgamma_part_q.data(), nullptr, sm_count,
                         workspace_q.data(), barrier_q.data());

        // RMSNorm has no beta, so its partial-dbeta buffer is empty.
        dbeta_part_shape = {0, 0};
    }

    auto workspace_entry = std::make_pair(MakeShapeVector(workspace_q.shape()), workspace_q.dtype());
    auto barrier_entry = std::make_pair(MakeShapeVector(barrier_q.shape()), barrier_q.dtype());
    auto dgamma_part_entry =
        std::make_pair(MakeShapeVector(dgamma_part_q.shape()), dgamma_part_q.dtype());
    auto dbeta_part_entry = std::make_pair(dbeta_part_shape, dbeta_part_q.dtype());
    return pybind11::make_tuple(workspace_entry, barrier_entry, dgamma_part_entry,
                                dbeta_part_entry);
}

542
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
543
                           size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
544
545
546
547
548
                           bool zero_centered_gamma, float eps, void *input, DType in_dtype,
                           void *weight, DType w_dtype, void *ograd, void *workspace,
                           DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
                           void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
                           DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
549
550
551
552
                           cudaStream_t stream) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
    auto intermediates_dtype = DType::kFloat32;
    auto is_layer_norm = (dbeta) ? true : false;

    // assume input type = output type
    auto *grad_output = ograd;
    auto x_dtype = in_dtype;
    auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);

    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);

    auto *x = input;
    auto x_tensor = TensorWrapper(x, input_shape, x_dtype);

    auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
    auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
    auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);

570
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
571
    auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
572

573
574
575
576
    auto workspace_shape = std::vector<size_t>{wkspace_size};
    auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
    auto barrier_shape = std::vector<size_t>{barrier_size};
    auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
577
578
    auto dgamma_part_tensor =
        TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
579
580
581
582

    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
583
584
        auto dbeta_part_tensor =
            TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
585

586
        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
587
588
589
590
591
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
                           dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                           barrier_tensor.data());
    } else {
592
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

/*
 * XLA custom-call entry point: LayerNorm forward producing an FP8 (E4M3) output.
 *
 * buffers: [0] input, [1] weight, [2] bias, [3] amax, [4] scale, [5] scale_inv,
 *          [6] output, [7] mu, [8] rsigma, [9] amax_out, [10] workspace, [11] barrier
 * opaque:  packed CustomCallNormDescriptor.
 *
 * LayerNormForwardImpl only receives `amax`, so any amax update lands in buffers[3];
 * buffers[9] must therefore alias buffers[3] for the result to be visible as an output.
 */
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *amax = reinterpret_cast<float *>(buffers[3]);
    auto *scale = reinterpret_cast<float *>(buffers[4]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
    auto *output = buffers[6];
    auto *mu = buffers[7];
    auto *rsigma = buffers[8];
    auto *amax_out = buffers[9];
    auto *workspace = buffers[10];
    auto *barrier = buffers[11];

    // assert() is compiled out under NDEBUG; a broken alias would silently leave the
    // amax output unwritten, so check it unconditionally.
    NVTE_CHECK(amax_out == amax, "amax output buffer must alias the amax input buffer.");

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    // NOTE(review): desc.sm_margin is not consumed: LayerNormForwardImpl has no SM-margin
    // parameter, so the margin requested by the caller is currently ignored.

    // The FP8 output format is fixed to E4M3; the descriptor carries no output dtype.
    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

/*
 * XLA custom-call entry point: LayerNorm forward with a non-FP8 output whose dtype equals
 * the input dtype.
 *
 * buffers: [0] input, [1] weight, [2] bias, [3] output, [4] mu, [5] rsigma,
 *          [6] workspace, [7] barrier
 * opaque:  packed CustomCallNormDescriptor.
 */
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *output = buffers[3];
    auto *mu = buffers[4];
    auto *rsigma = buffers[5];
    auto *workspace = buffers[6];
    auto *barrier = buffers[7];

    // No FP8 scaling metadata for a non-FP8 output.
    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto out_dtype = in_dtype;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    // NOTE(review): desc.sm_margin is not consumed: LayerNormForwardImpl has no SM-margin
    // parameter, so the margin requested by the caller is currently ignored.

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

/*
 * XLA custom-call entry point: LayerNorm backward.
 *
 * buffers: [0] ograd, [1] mu, [2] rsigma, [3] input, [4] weight,
 *          [5] xgrad, [6] wgrad, [7] dbeta, [8] workspace, [9] barrier,
 *          [10] dgamma_part, [11] dbeta_part
 * opaque:  packed CustomCallNormDescriptor (sizes, dtypes and the partial-reduction
 *          shapes/dtypes for dgamma and dbeta).
 */
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);

    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto dgamma_part_shape = desc.dgamma_part_shape;
    auto dbeta_part_shape = desc.dbeta_part_shape;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto dgamma_part_dtype = desc.dgamma_part_dtype;
    auto dbeta_part_dtype = desc.dbeta_part_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    // NOTE(review): desc.sm_margin is not consumed: LayerNormBackwardImpl has no SM-margin
    // parameter, so the margin requested by the caller is currently ignored.

    auto *ograd = buffers[0];
    auto *mu = buffers[1];
    auto *rsigma = buffers[2];
    auto *input = buffers[3];
    auto *weight = buffers[4];
    auto *xgrad = buffers[5];
    auto *wgrad = buffers[6];
    auto *dbeta = buffers[7];
    auto *workspace = buffers[8];
    auto *barrier = buffers[9];
    auto *dgamma_part = buffers[10];
    auto *dbeta_part = buffers[11];

    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                          w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                          rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                          dbeta_part_dtype, stream);
}

/*
 * XLA custom-call entry point: RMSNorm forward producing an FP8 (E4M3) output.
 * Dispatched through LayerNormForwardImpl with null bias and null mu.
 *
 * buffers: [0] input, [1] weight, [2] amax, [3] scale, [4] scale_inv,
 *          [5] output, [6] rsigma, [7] amax_out, [8] workspace, [9] barrier
 * opaque:  packed CustomCallNormDescriptor.
 *
 * LayerNormForwardImpl only receives `amax`, so any amax update lands in buffers[2];
 * buffers[7] must therefore alias buffers[2] for the result to be visible as an output.
 */
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *amax = reinterpret_cast<float *>(buffers[2]);
    auto *scale = reinterpret_cast<float *>(buffers[3]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *rsigma = buffers[6];
    auto *amax_out = buffers[7];
    auto *workspace = buffers[8];
    auto *barrier = buffers[9];

    // assert() is compiled out under NDEBUG; a broken alias would silently leave the
    // amax output unwritten, so check it unconditionally.
    NVTE_CHECK(amax_out == amax, "amax output buffer must alias the amax input buffer.");

    // RMSNorm has neither a bias nor a mean.
    void *bias = nullptr;
    void *mu = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    // NOTE(review): desc.sm_margin is not consumed: LayerNormForwardImpl has no SM-margin
    // parameter, so the margin requested by the caller is currently ignored.

    // The FP8 output format is fixed to E4M3; the descriptor carries no output dtype.
    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

/*
 * XLA custom-call entry point: RMSNorm forward with a non-FP8 output whose dtype equals
 * the input dtype. Dispatched through LayerNormForwardImpl with null bias and null mu.
 *
 * buffers: [0] input, [1] weight, [2] output, [3] rsigma, [4] workspace, [5] barrier
 * opaque:  packed CustomCallNormDescriptor.
 */
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *output = buffers[2];
    auto *rsigma = buffers[3];
    auto *workspace = buffers[4];
    auto *barrier = buffers[5];

    // RMSNorm has neither a bias nor a mean, and a non-FP8 output needs no scaling metadata.
    void *bias = nullptr;
    void *mu = nullptr;
    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    // NOTE(review): desc.sm_margin is not consumed: LayerNormForwardImpl has no SM-margin
    // parameter, so the margin requested by the caller is currently ignored.
    auto out_dtype = in_dtype;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

/*
 * XLA custom-call entry point: RMSNorm backward, routed through LayerNormBackwardImpl.
 *
 * buffers: [0] ograd, [1] rsigma, [2] input, [3] weight, [4] xgrad, [5] wgrad,
 *          [6] workspace, [7] barrier, [8] dgamma_part
 * opaque:  packed CustomCallNormDescriptor.
 *
 * The LayerNorm-only pieces (mu, dbeta, dbeta_part) are passed as null pointers, with a
 * {0, 0} dbeta_part shape and kByte dtype as placeholders.
 */
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);

    // Placeholders for the LayerNorm-only inputs/outputs.
    void *mu = nullptr;
    void *dbeta = nullptr;
    void *dbeta_part = nullptr;
    Shape dbeta_part_shape;
    dbeta_part_shape.from_vector({0, 0});
    auto dbeta_part_dtype = DType::kByte;

    auto dgamma_part_shape = desc.dgamma_part_shape;

    void *ograd = buffers[0];
    void *rsigma = buffers[1];
    void *input = buffers[2];
    void *weight = buffers[3];
    void *xgrad = buffers[4];
    void *wgrad = buffers[5];
    void *workspace = buffers[6];
    void *barrier = buffers[7];
    void *dgamma_part = buffers[8];

    LayerNormBackwardImpl(desc.batch_size, desc.hidden_size, desc.wkspace_size,
                          desc.barrier_size, dgamma_part_shape, dbeta_part_shape,
                          desc.zero_centered_gamma, desc.eps, input, desc.x_dtype, weight,
                          desc.w_dtype, ograd, workspace, desc.wkspace_dtype, barrier,
                          desc.barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part,
                          desc.dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream);
}

/*
 * XLA custom-call entry point: cast `input` to the FP8 output dtype via nvte_fp8_quantize.
 *
 * buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output, [5] amax_out
 * opaque:  packed CustomCallCommonDescriptor (shape, in_dtype, out_dtype).
 *
 * The output tensor is bound to amax_out, which is expected to alias the incoming amax
 * buffer so the kernel updates the caller's amax in place.
 */
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *amax = reinterpret_cast<float *>(buffers[1]);
    auto *scale = reinterpret_cast<float *>(buffers[2]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    auto *amax_out = reinterpret_cast<float *>(buffers[5]);
    // assert() is compiled out under NDEBUG; enforce the aliasing contract in every build.
    NVTE_CHECK(amax == amax_out, "amax output buffer must alias the amax input buffer.");

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto shape = desc.shape.to_vector();
    auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
    auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);

    nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}

/*
 * XLA custom-call entry point: convert an FP8 tensor back to `out_dtype` via
 * nvte_fp8_dequantize.
 *
 * buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output
 * opaque:  packed CustomCallCommonDescriptor (shape, in_dtype, out_dtype).
 */
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    const auto dims = desc.shape.to_vector();

    // The input carries the FP8 scaling metadata; the output is a plain tensor.
    auto *amax_ptr = reinterpret_cast<float *>(buffers[1]);
    auto *scale_ptr = reinterpret_cast<float *>(buffers[2]);
    auto *scale_inv_ptr = reinterpret_cast<float *>(buffers[3]);
    TensorWrapper fp8_in(buffers[0], dims, desc.in_dtype, amax_ptr, scale_ptr, scale_inv_ptr);
    TensorWrapper plain_out(buffers[4], dims, desc.out_dtype);

    nvte_fp8_dequantize(fp8_in.data(), plain_out.data(), stream);
}

/*
 * XLA custom-call entry point: scaled (unmasked) softmax forward.
 *
 * buffers: [0] input, [1] output, both of shape
 *          {batch_size, head_dim, q_seqlen, k_seqlen} and dtype desc.dtype
 * opaque:  packed SoftmaxDescriptor (also supplies scale_factor).
 */
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const std::vector<size_t> logits_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                           desc.k_seqlen};
    const auto dtype = desc.dtype;

    TensorWrapper logits(buffers[0], logits_shape, dtype);
    TensorWrapper probs(buffers[1], logits_shape, dtype);

    nvte_scaled_softmax_forward(logits.data(), probs.data(), desc.scale_factor, stream);
}

/*
 * XLA custom-call entry point: scaled softmax backward.
 *
 * buffers: [0] grad_output, [1] softmax_output, [2] dgrad, all of shape
 *          {batch_size, head_dim, q_seqlen, k_seqlen} and dtype desc.dtype
 * opaque:  packed SoftmaxDescriptor (also supplies scale_factor).
 */
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const std::vector<size_t> io_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                       desc.k_seqlen};
    const auto dtype = desc.dtype;

    TensorWrapper incoming_grad(buffers[0], io_shape, dtype);
    TensorWrapper saved_softmax(buffers[1], io_shape, dtype);
    TensorWrapper input_grad(buffers[2], io_shape, dtype);

    nvte_scaled_softmax_backward(incoming_grad.data(), saved_softmax.data(), input_grad.data(),
                                 desc.scale_factor, stream);
}

/*
 * XLA custom-call entry point: scaled softmax forward with an explicit mask.
 *
 * buffers: [0] input  {batch_size, head_dim, q_seqlen, k_seqlen}, desc.dtype
 *          [1] mask   {padding_size, 1, q_seqlen, k_seqlen}, wrapped as kByte
 *          [2] output {batch_size, head_dim, q_seqlen, k_seqlen}, desc.dtype
 * opaque:  packed SoftmaxDescriptor (also supplies scale_factor).
 */
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                size_t opaque_len) {
    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const auto dtype = desc.dtype;

    const std::vector<size_t> io_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                       desc.k_seqlen};
    const std::vector<size_t> mask_shape{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};

    TensorWrapper input_tensor(buffers[0], io_shape, dtype);
    // The mask buffer is handed to the kernel as raw bytes (uint8_t).
    TensorWrapper mask_tensor(buffers[1], mask_shape, DType::kByte);
    TensorWrapper output_tensor(buffers[2], io_shape, dtype);

    nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
                                       output_tensor.data(), desc.scale_factor, stream);
}

/*
 * XLA custom-call entry point: backward of the masked scaled softmax.
 *
 * The backward takes only the incoming gradient and the saved softmax output, so it is
 * identical to the unmasked case. Forwards to ScaledSoftmaxBackward with the same
 * buffers and descriptor.
 */
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 size_t opaque_len) {
    return ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}

/*
 * XLA custom-call entry point: scaled softmax forward with an implicit upper-triangular mask.
 *
 * buffers: [0] input, [1] output, both viewed as a 3-D tensor
 *          {batch_size * head_dim, q_seqlen, k_seqlen} of dtype desc.dtype
 *          (batch and head dimensions are folded together)
 * opaque:  packed SoftmaxDescriptor (also supplies scale_factor).
 */
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           size_t opaque_len) {
    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const std::vector<size_t> folded_shape{desc.batch_size * desc.head_dim, desc.q_seqlen,
                                           desc.k_seqlen};

    TensorWrapper input_tensor(buffers[0], folded_shape, desc.dtype);
    TensorWrapper output_tensor(buffers[1], folded_shape, desc.dtype);

    nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
                                                    desc.scale_factor, stream);
}

/*
 * XLA custom-call entry point: backward of the upper-triangular-masked scaled softmax.
 *
 * buffers: [0] grad_output, [1] softmax_output, [2] dgrad, all viewed as
 *          {batch_size * head_dim, q_seqlen, k_seqlen} of dtype desc.dtype
 * opaque:  packed SoftmaxDescriptor (also supplies scale_factor).
 */
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            size_t opaque_len) {
    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    const std::vector<size_t> folded_shape{desc.batch_size * desc.head_dim, desc.q_seqlen,
                                           desc.k_seqlen};
    const auto dtype = desc.dtype;

    TensorWrapper incoming_grad(buffers[0], folded_shape, dtype);
    TensorWrapper saved_softmax(buffers[1], folded_shape, dtype);
    TensorWrapper input_grad(buffers[2], folded_shape, dtype);

    nvte_scaled_upper_triang_masked_softmax_backward(incoming_grad.data(), saved_softmax.data(),
                                                     input_grad.data(), desc.scale_factor,
                                                     stream);
}

952
953
954
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
955
                                            size_t q_attn_heads, size_t kv_attn_heads,
956
957
958
959
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim) {
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
960
        mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
zlsh80826's avatar
zlsh80826 committed
961
        head_dim);
962
963
964
    return backend;
}

965
966
967
968
969
970
/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
971
972
973
974
975
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
                                       const CustomCallFusedAttnDescriptor *desc,
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
976
977
978
979
    auto input_batch = desc->input_batch;
    auto bias_batch = desc->bias_batch;
    auto attn_heads = desc->attn_heads;
    auto bias_heads = desc->bias_heads;
980
981
982
983
984
985
986
987
    auto q_max_seqlen = desc->q_max_seqlen;
    auto kv_max_seqlen = desc->kv_max_seqlen;

    // all backends need softmax but expect different shapes/dtypes
    // start with the max512 sequence length softmax shape/dtype and correct later
    tensor_pack->size = 1;
    Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
    softmax_aux->data.dptr = softmax_buf;
988
    softmax_aux->data.shape =
989
        std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
    softmax_aux->data.dtype = desc->dtype;

    // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
        tensor_pack->size = 2;
        Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
        rng_state_aux->data.dptr = rng_state_buf;
        rng_state_aux->data.shape = std::vector<size_t>{2};
        rng_state_aux->data.dtype = DType::kInt64;
        // correct softmax shape/dtype
        softmax_aux->data.shape.at(3) = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
        softmax_aux->data.dtype = DType::kFloat32;

        // include bias if enabled
        if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
            tensor_pack->size = 3;
            Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
            bias_aux->data.dptr = bias_buf;
1008
1009
            bias_aux->data.shape =
                std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
            bias_aux->data.dtype = desc->dtype;
        }
    }
}

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
1023
1024
1025
1026
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
                                        const CustomCallFusedAttnDescriptor *desc,
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
1027
1028
1029
1030
1031
1032
1033
1034
1035
    // Backward calls put everything into the tensor pack for every backend
    // so we set dummy bias_type and backend choices here to follow the correct code path
    auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
    auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
    PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend,
                                      softmax_buf, rng_state_buf, bias_buf);

    // correct softmax shape for max512 sequence length kernel
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
1036
        Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
1037
1038
1039
1040
1041
        softmax_aux->data.shape.at(3) = desc->kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
        softmax_aux->data.dtype = desc->dtype;
    }
}

1042
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
1043
1044
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
1045
1046
1047
1048
1049
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
    // For qkv_packed
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
1050

1051
    // For kv_packed
1052
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
1053
    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
1054
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
1055
1056
    auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

1057
1058
1059
1060
1061
1062
1063
    // For separate q, k, v
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto v_shape = k_shape;
    auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
1064
1065
    auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

1066
    // F16 doesn't use this tensor
1067
1068
1069
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
    auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

1070
    auto q_cu_seqlens_tensor =
1071
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
1072
    auto kv_cu_seqlens_tensor =
1073
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
1074
1075
1076
1077
1078
1079
1080

    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    TensorWrapper query_workspace_tensor;
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        assert(q_max_seqlen == kv_max_seqlen);
        nvte_fused_attn_fwd_qkvpacked(
            qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(),
            q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
            mask_type, query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
                                     s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                                     q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                                     dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                                     is_training, scaling_factor, dropout_probability, qkv_layout,
                                     bias_type, mask_type, query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                            s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                            dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                            query_workspace_tensor.data(), nullptr);
    } else {
        NVTE_ERROR("Unsupported QKVLayout.");
    }
1105

1106
1107
    auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
1108
}
1109

1110
1111
1112
1113
1114
1115
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads,
    size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype,
    bool is_training) {
    auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
1116
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
1117
    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
1118

1119
    auto bias_shape = std::vector<size_t>{1, attn_heads, q_max_seqlen, kv_max_seqlen};
1120
1121
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

1122
1123
    // F16 doesn't use s_tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
1124

1125
    auto q_cu_seqlens_tensor =
1126
        TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
1127
    auto kv_cu_seqlens_tensor =
1128
        TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
1129
1130
1131

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
1132
1133
1134

    TensorWrapper query_workspace_tensor;

1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        assert(q_max_seqlen == kv_max_seqlen);
        auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
        auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
        nvte_fused_attn_bwd_qkvpacked(
            qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
            q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto kv_shape =
            std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
        auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
        nvte_fused_attn_bwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
        auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
        auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
        auto v_shape = k_shape;
        auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
        auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
        nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                            doutput_tensor.data(),
                            s_tensor.data(),  // not used for F16
                            s_tensor.data(),  // not used for F16
                            &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
                            dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                            kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
                            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                            query_workspace_tensor.data(), nullptr);
    } else {
        NVTE_ERROR("Unsupported QKVLayout.");
    }

    auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
1188
1189
1190
1191
1192
1193
}

/*
 * XLA custom-call entry point for the fused-attention forward pass.
 *
 * `opaque` must contain a serialized CustomCallFusedAttnDescriptor (its size is validated by
 * UnpackOpaque). Buffers used by this function:
 *   inputs : [0-2] q/k/v (how many are read depends on qkv_layout), [3] bias,
 *            [4] q_cu_seqlens, [5] kv_cu_seqlens, [6] seed
 *   outputs: [7] output, [8] softmax_aux, [9] rng_state, [10] workspace
 * All NVTE work is issued on `stream`. An unsupported qkv_layout is reported via NVTE_ERROR.
 */
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    /* Input buffers from XLA */
    /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
    void *bias = buffers[3];
    void *q_cu_seqlens = buffers[4];
    void *kv_cu_seqlens = buffers[5];
    void *seed = buffers[6];

    /* Output buffers from XLA */
    void *output = buffers[7];
    void *softmax_aux = buffers[8];
    void *rng_state = buffers[9];
    void *workspace = buffers[10];

    /* Descriptor */
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    auto qkv_layout = descriptor.qkv_layout;
    auto dtype = descriptor.dtype;

    /* Input tensors */
    // Unpacked q/k/v shapes; the layout-specific branches below reuse these rather than
    // re-declaring identical shadowing copies.
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

    /* Output tensors */
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
    auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto o_tensor = TensorWrapper(output, o_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    /* Prepare RNG state */
    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

    /* Auxiliary tensors (to be propagated to the backward pass later) */
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);
    // Scope guard: destroys the pack on every exit path, including when control leaves this
    // function through an exception (e.g. the NVTE_ERROR branch below).
    struct TensorPackGuard {
        NVTETensorPack *pack;
        ~TensorPackGuard() { nvte_tensor_pack_destroy(pack); }
    } aux_output_guard{&aux_output_tensors};
    PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                      softmax_aux);

    /* cuDNN workspace */
    auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                          descriptor.wkspace_dtype);

    /* Call the underlying NVTE API */
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        // q, k and v are packed together in buffers[0].
        auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(buffers[0], qkv_shape, dtype);
        nvte_fused_attn_fwd_qkvpacked(
            qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
            descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
            mask_type, workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        // q in buffers[0]; k and v packed together in buffers[1].
        auto q_tensor = TensorWrapper(buffers[0], q_shape, dtype);
        auto kv_shape =
            std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(buffers[1], kv_shape, dtype);
        nvte_fused_attn_fwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
            &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
            rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        // q, k and v each in their own buffer.
        auto q_tensor = TensorWrapper(buffers[0], q_shape, dtype);
        auto k_tensor = TensorWrapper(buffers[1], k_shape, dtype);
        auto v_tensor = TensorWrapper(buffers[2], v_shape, dtype);
        nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                            s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                            rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                            descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
                            bias_type, mask_type, workspace_tensor.data(), stream);
    } else {
        NVTE_ERROR("Unsupported qkv_layout.");
    }
    // aux_output_tensors is released by aux_output_guard when this scope ends.
}

/*
 * Queries the workspace shape and dtype needed by the fused-attention backward pass.
 *
 * Every tensor is built with a null data pointer, and the NVTE backward entry point is
 * called with a null stream so that it only fills in `query_workspace_tensor`.
 * NOTE(review): this query-mode contract is inferred from how the call is made here; confirm it
 * against the NVTE fused_attn API.
 *
 * The NVTE entry point is chosen per qkv_layout to mirror FusedAttnBackward. That way the
 * reported workspace belongs to the same API call that will later run (packed layouts use the
 * *_qkvpacked / *_kvpacked variants).
 *
 * Returns a Python tuple (workspace_shape, workspace_dtype). Reports an unsupported qkv_layout
 * via NVTE_ERROR. `is_training` is currently unused and is kept for interface compatibility.
 */
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
    // Scope guard: every created pack must be destroyed; this releases it on all exit paths,
    // including the NVTE_ERROR branch below.
    struct TensorPackGuard {
        NVTETensorPack *pack;
        ~TensorPackGuard() { nvte_tensor_pack_destroy(pack); }
    } aux_input_guard{&aux_input_tensors};

    TensorWrapper query_workspace_tensor;
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
        auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
        nvte_fused_attn_bwd_qkvpacked(
            qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
            q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto kv_shape =
            std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
        auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
        nvte_fused_attn_bwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            query_workspace_tensor.data(), nullptr);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
        auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
        auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
        auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
        auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
        nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                            doutput_tensor.data(),
                            s_tensor.data(),  // not used for F16
                            s_tensor.data(),  // not used for F16
                            &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
                            dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                            kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
                            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                            query_workspace_tensor.data(), nullptr);
    } else {
        NVTE_ERROR("Unsupported qkv_layout.");
    }

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

/*
 * XLA custom-call entry point for the fused-attention backward pass.
 *
 * `opaque` must contain a serialized CustomCallFusedAttnDescriptor (its size is validated by
 * UnpackOpaque). Buffers used by this function:
 *   inputs : [0-2] q/k/v (how many are read depends on qkv_layout), [3] bias, [4] softmax_aux,
 *            [5] rng_state, [6] output, [7] doutput, [8] q_cu_seqlens, [9] kv_cu_seqlens
 *   outputs: [10-12] dq/dk/dv (how many are written depends on qkv_layout), [13] dbias,
 *            [14] workspace
 * All NVTE work is issued on `stream`. An unsupported qkv_layout is reported via NVTE_ERROR.
 */
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    /* Input buffers from XLA */
    /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
    void *bias = buffers[3];
    void *softmax_aux = buffers[4];
    void *rng_state = buffers[5];
    void *output = buffers[6];
    void *doutput = buffers[7];
    void *q_cu_seqlens = buffers[8];
    void *kv_cu_seqlens = buffers[9];

    /* Output buffers from XLA */
    /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */
    void *dbias = buffers[13];
    void *workspace = buffers[14];

    /* Descriptor */
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
    auto qkv_layout = descriptor.qkv_layout;
    auto dtype = descriptor.dtype;

    /* Input tensors */
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

    /* Output tensors */
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    /* Auxiliary tensors (propagated from the forward pass) */
    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
    // Scope guard: destroys the pack on every exit path, including when control leaves this
    // function through an exception (e.g. the NVTE_ERROR branch below).
    struct TensorPackGuard {
        NVTETensorPack *pack;
        ~TensorPackGuard() { nvte_tensor_pack_destroy(pack); }
    } aux_input_guard{&aux_input_tensors};
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                       rng_state, bias);

    /* cuDNN workspace */
    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
    auto wkspace_dtype = descriptor.wkspace_dtype;
    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

    /* Call the underlying NVTE API */
    if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
        // q, k and v (and their gradients) are packed into a single buffer each.
        auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
        auto qkv_tensor = TensorWrapper(buffers[0], qkv_shape, dtype);
        auto dqkv_tensor = TensorWrapper(buffers[10], qkv_shape, dtype);
        nvte_fused_attn_bwd_qkvpacked(
            qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
            q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
        // q separate; k and v (and dk/dv) packed together.
        auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
        auto q_tensor = TensorWrapper(buffers[0], q_shape, dtype);
        auto kv_shape =
            std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
        auto kv_tensor = TensorWrapper(buffers[1], kv_shape, dtype);
        auto dq_tensor = TensorWrapper(buffers[10], q_shape, dtype);
        auto dkv_tensor = TensorWrapper(buffers[11], kv_shape, dtype);
        nvte_fused_attn_bwd_kvpacked(
            q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
            s_tensor.data(),  // not used for F16
            s_tensor.data(),  // not used for F16
            &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
            q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
            workspace_tensor.data(), stream);
    } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
        // q, k, v and dq, dk, dv each in their own buffer.
        auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
        auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
        auto v_shape = k_shape;
        auto q_tensor = TensorWrapper(buffers[0], q_shape, dtype);
        auto k_tensor = TensorWrapper(buffers[1], k_shape, dtype);
        auto v_tensor = TensorWrapper(buffers[2], v_shape, dtype);
        auto dq_tensor = TensorWrapper(buffers[10], q_shape, dtype);
        auto dk_tensor = TensorWrapper(buffers[11], k_shape, dtype);
        auto dv_tensor = TensorWrapper(buffers[12], v_shape, dtype);
        nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                            doutput_tensor.data(),
                            s_tensor.data(),  // not used for F16
                            s_tensor.data(),  // not used for F16
                            &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
                            dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                            kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
                            scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                            workspace_tensor.data(), stream);
    } else {
        NVTE_ERROR("Unsupported qkv_layout.");
    }
    // aux_input_tensors is released by aux_input_guard when this scope ends.
}

1481
1482
}  // namespace jax
}  // namespace transformer_engine