/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "jax/csrc/modules.h"

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>

#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }

std::vector<size_t> MakeShapeVector(NVTEShape shape) {
    return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

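// Helpers that (de)serialize host-side descriptor structs through the opaque byte string XLA
// hands to each custom call.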
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
    auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
    return pybind11::bytes(str);
}

template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
    if (opaque_len != sizeof(T)) {
        throw std::runtime_error("Invalid opaque object size");
    }
    return reinterpret_cast<const T *>(opaque);
}

pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype) {
    CustomCallCommonDescriptor desc;
    desc.shape.from_vector(shape);
    desc.in_dtype = in_dtype;
    desc.out_dtype = out_dtype;
    return PackOpaque(desc);
}

pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype) {
    CustomCallCommonWkDescriptor desc;
    desc.shape.from_vector(shape);
    desc.wkshape.from_vector(wkshape);
    desc.in_dtype = in_dtype;
    desc.out_dtype = out_dtype;
    desc.wk_dtype = wk_dtype;
    return PackOpaque(desc);
}

pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
                                             size_t wkspace_size, size_t barrier_size,
                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
                                             DType barrier_dtype, DType dgamma_part_dtype,
                                             DType dbeta_part_dtype, bool zero_centered_gamma,
                                             float eps, int sm_margin) {
    return PackOpaque(CustomCallNormDescriptor{
        batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
        x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
        zero_centered_gamma, eps, sm_margin});
}

pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor) {
    return PackOpaque(SoftmaxDescriptor{batch_size, padding_size, head_dim, q_seqlen, k_seqlen,
                                        dtype, scale_factor});
}

pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    DType dtype, DType wkspace_dtype, bool is_training) {
    return PackOpaque(CustomCallFusedAttnDescriptor{
        input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
        bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability,
        bias_type, mask_type, dtype, wkspace_dtype, is_training});
}

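// Custom-call convention used throughout this file: `buffers` lists the XLA input buffers first,
// followed by the output buffers, while the opaque descriptor carries the shapes/dtypes that were
// resolved when the primitive was lowered.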
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
    auto input_shape = std::vector<size_t>{rows, cols};
    auto output_shape = std::vector<size_t>{cols, rows};

    auto input_tensor = TensorWrapper(input, input_shape, dtype);
    auto transposed_tensor = TensorWrapper(output, output_shape, dtype);

    nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}

void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    void *input = buffers[0];
    void *output = buffers[1];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto rows = desc.shape.dims[0];
    auto cols = desc.shape.dims[1];
    assert(desc.in_dtype == desc.out_dtype);
    auto dtype = desc.out_dtype;

    TransposeImpl(input, rows, cols, dtype, stream, output);
}

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *input_cast = buffers[4];
    auto *input_cast_trans = buffers[5];
    float *amax_out = reinterpret_cast<float *>(buffers[6]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = std::vector<size_t>{m, n};
    auto input_trans_shape = std::vector<size_t>{n, m};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto input_cast_tensor =
        TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape,
                                                 desc.out_dtype, amax_out, scale, scale_inv);

    nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(),
                        input_cast_trans_tensor.data(), stream);
}

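// GELU forward. The amax/scale/scale_inverse pointers attach FP8 scaling metadata to the output
// tensor; the non-FP8 path (Gelu) simply passes nullptr for them.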
void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
              cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
    auto input_shape = std::vector<size_t>{m, n};
    auto output_shape = std::vector<size_t>{m, n};

    auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));

    auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                       scale, scale_inverse);

    nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
}

void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];

    GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output);
}

void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    float *amax_out = reinterpret_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];

    GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
             output);
}

void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *gelu_input = buffers[1];
    auto *output = buffers[2];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = std::vector<size_t>{m, n};
    auto gelu_input_shape = std::vector<size_t>{m, n};
    auto output_shape = std::vector<size_t>{m, n};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
    auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);

    nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}

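// Workspace-size query: calling the fused kernel with null data pointers and an empty workspace
// tensor makes the library report the required workspace shape/dtype without doing any work.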
pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto gelu_input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto output_shape = std::vector<size_t>{batch_size, hidden_size};
    auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
    auto dbias_shape = std::vector<size_t>{hidden_size};

    auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto gelu_input_tensor = TensorWrapper(nullptr, gelu_input_shape, in_dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype);
    auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype);
    auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype);

    TensorWrapper dummy_workspace;

    nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
                                    output_tensor.data(), output_trans_tensor.data(),
                                    dbias_tensor.data(), dummy_workspace.data(), nullptr);

    auto work_shape = MakeShapeVector(dummy_workspace.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len) {
    auto *input = buffers[0];
    auto *gelu_input = buffers[1];
    float *amax = reinterpret_cast<float *>(buffers[2]);
    float *scale = reinterpret_cast<float *>(buffers[3]);
    float *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *output_trans = buffers[6];
    auto *dbias = buffers[7];
    float *amax_out = reinterpret_cast<float *>(buffers[8]);
    void *workspace_ptr = buffers[9];

    const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
    assert(amax == amax_out);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = std::vector<size_t>{m, n};
    auto gelu_input_shape = std::vector<size_t>{m, n};
    auto output_shape = std::vector<size_t>{m, n};
    auto output_trans_shape = std::vector<size_t>{n, m};
    auto dbias_shape = std::vector<size_t>{n};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
    auto output_tensor =
        TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto output_trans_tensor =
        TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);

    auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

    nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(),
                                    output_tensor.data(), output_trans_tensor.data(),
                                    dbias_tensor.data(), workspace.data(), stream);
}

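// Gated GELU (GEGLU): the input packs the gate and value halves along the last dimension, so the
// input shape is {m, 2 * n} while the output shape is {m, n}.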
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
                   cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
    auto input_shape = std::vector<size_t>{m, n * 2};
    auto output_shape = std::vector<size_t>{m, n};

    auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));

    auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                       scale, scale_inverse);

    nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
}

void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];

    GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
                  output);
}

void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    float *amax = reinterpret_cast<float *>(buffers[1]);
    float *scale = reinterpret_cast<float *>(buffers[2]);
    float *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    float *amax_out = reinterpret_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];

    GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
                  output);
}

void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *gelu_input = buffers[1];
    auto *output = buffers[2];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = std::vector<size_t>{m, n};
    auto gelu_input_shape = std::vector<size_t>{m, n * 2};
    auto output_shape = std::vector<size_t>{m, n * 2};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
    auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);

    nvte_dgeglu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}

void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len) {
    auto *input = buffers[0];
    auto *gelu_input = buffers[1];
    float *amax = reinterpret_cast<float *>(buffers[2]);
    float *scale = reinterpret_cast<float *>(buffers[3]);
    float *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *output_trans = buffers[6];
    float *amax_out = reinterpret_cast<float *>(buffers[7]);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    assert(amax == amax_out);
    if (!use_fp8(desc.out_dtype)) {
        scale = nullptr;
        scale_inv = nullptr;
        amax_out = nullptr;
    }
    auto m = desc.shape.dims[0];
    auto n = desc.shape.dims[1];
    auto input_shape = desc.shape.to_vector();
    auto gelu_input_shape = std::vector<size_t>{m, n * 2};
    auto output_shape = std::vector<size_t>{m, n * 2};
    auto output_trans_shape = std::vector<size_t>{n * 2, m};

    auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
    auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
    auto output_tensor =
        TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
    auto output_trans_tensor =
        TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);

    nvte_dgeglu_cast_transpose(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(),
                               output_trans_tensor.data(), stream);
}

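// Query the workspace and barrier sizes required by the LayerNorm/RMSNorm forward kernels, using
// the same null-pointer query pattern as above.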
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};

    // empty tensor wrappers are okay just to get workspace size
    auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
    auto output_tensor = TensorWrapper(nullptr, input_shape, out_dtype);
    auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);

    // dummy tensor wrappers that will carry workspace size info later
    TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);

        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr,
                           num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
                         dummy_barrier_tensor.data());
    }

    auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
    auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
                                std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()));
}

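// Shared forward implementation for LayerNorm and RMSNorm: a non-null bias selects LayerNorm
// (which also writes mu), otherwise the RMSNorm kernel is used. Non-null amax/scale/scale_inv
// indicate an FP8 output.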
void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size,
                          size_t barrier_size, bool zero_centered_gamma, float eps, void *input,
                          DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
                          DType out_dtype, void *workspace, DType work_dtype, void *barrier,
                          DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
                          float *scale_inv, cudaStream_t stream) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto workspace_shape = std::vector<size_t>{workspace_size};
    auto barrier_shape = std::vector<size_t>{barrier_size};
    auto is_layer_norm = (bias) ? true : false;

    auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);

    // assume output dtype = input dtype
    // If we need mixed I/O precision in the future, we need an additional
    // parameter for output type
    auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);

    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;

    auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
    auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);

    if (is_layer_norm) {
        auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);

        layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
                           output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
                           num_sm, workspace_tensor.data(), barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
                         rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

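// Query workspace, barrier, and dgamma/dbeta partial-reduction buffer sizes for the backward
// kernels. RMSNorm has no dbeta, so its dbeta_part shape is reported as {0, 0}.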
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto intermediates_dtype = DType::kFloat32;

    // empty tensor wrappers are okay just to get workspace size
    auto dz_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
    auto x_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
    auto xgrad_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
    auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);

    // dummy tensor wrappers that will carry workspace size info later
    TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
    TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;

    // initialize dBeta information here -- layernorm will modify but RMSnorm will not
    std::vector<size_t> dbeta_part_shape;
    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);

        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(),
                           dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), nullptr,
                           num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data());

        dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dummy_dgamma_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(),
                         dummy_barrier_tensor.data());

        dbeta_part_shape = std::vector<size_t>{0, 0};
    }

    auto work_shape = MakeShapeVector(dummy_work_tensor.shape());
    auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape());
    auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape());
    return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()),
                                std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()),
                                std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()),
                                std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype()));
}

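// Shared backward implementation: a non-null dbeta selects LayerNorm, otherwise RMSNorm.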
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
                           size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
                           bool zero_centered_gamma, float eps, void *input, DType in_dtype,
                           void *weight, DType w_dtype, void *ograd, void *workspace,
                           DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
                           void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
                           DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
                           cudaStream_t stream) {
    auto input_shape = std::vector<size_t>{batch_size, hidden_size};
    auto weight_shape = std::vector<size_t>{hidden_size};
    auto intermediates_shape = std::vector<size_t>{batch_size};
    auto intermediates_dtype = DType::kFloat32;
    auto is_layer_norm = (dbeta) ? true : false;

    // assume input type = output type
    auto *grad_output = ograd;
    auto x_dtype = in_dtype;
    auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);

    auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);

    auto *x = input;
    auto x_tensor = TensorWrapper(x, input_shape, x_dtype);

    auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
    auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
    auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);

    auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
    auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;

    auto workspace_shape = std::vector<size_t>{wkspace_size};
    auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
    auto barrier_shape = std::vector<size_t>{barrier_size};
    auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
    auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
    auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);

    if (is_layer_norm) {
        auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
        auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
        auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
        auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);

        layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                           rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
                           wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
                           dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                           barrier_tensor.data());
    } else {
        NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
        nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
                         gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
                         dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
                         barrier_tensor.data());
    }
}

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *amax = reinterpret_cast<float *>(buffers[3]);
    auto *scale = reinterpret_cast<float *>(buffers[4]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
    auto *output = buffers[6];
    auto *mu = buffers[7];
    auto *rsigma = buffers[8];
    auto *amax_out = buffers[9];
    auto *workspace = buffers[10];
    auto *barrier = buffers[11];
    assert(amax_out == amax);

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *bias = buffers[2];
    auto *output = buffers[3];
    auto *mu = buffers[4];
    auto *rsigma = buffers[5];
    auto *workspace = buffers[6];
    auto *barrier = buffers[7];

    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto out_dtype = in_dtype;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);

    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto *dgamma_part_sizes = desc.dgamma_part_sizes;
    auto *dbeta_part_sizes = desc.dbeta_part_sizes;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto dgamma_part_dtype = desc.dgamma_part_dtype;
    auto dbeta_part_dtype = desc.dbeta_part_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;

    auto *ograd = buffers[0];
    auto *mu = buffers[1];
    auto *rsigma = buffers[2];
    auto *input = buffers[3];
    auto *weight = buffers[4];
    auto *xgrad = buffers[5];
    auto *wgrad = buffers[6];
    auto *dbeta = buffers[7];
    auto *workspace = buffers[8];
    auto *barrier = buffers[9];
    auto *dgamma_part = buffers[10];
    auto *dbeta_part = buffers[11];

    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
                          dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
                          w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                          rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                          dbeta_part_dtype, stream);
}

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *amax = reinterpret_cast<float *>(buffers[2]);
    auto *scale = reinterpret_cast<float *>(buffers[3]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
    auto *output = buffers[5];
    auto *rsigma = buffers[6];
    auto *amax_out = buffers[7];
    auto *workspace = buffers[8];
    auto *barrier = buffers[9];
    assert(amax_out == amax);

    void *bias = nullptr;
    void *mu = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;
    auto out_dtype = DType::kFloat8E4M3;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *weight = buffers[1];
    auto *output = buffers[2];
    auto *rsigma = buffers[3];
    auto *workspace = buffers[4];
    auto *barrier = buffers[5];

    void *bias = nullptr;
    void *mu = nullptr;
    float *amax = nullptr;
    float *scale = nullptr;
    float *scale_inv = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;
    auto sm_margin = desc.sm_margin;
    auto out_dtype = in_dtype;

    LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                         eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                         wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                         stream);
}

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *ograd = buffers[0];
    auto *rsigma = buffers[1];
    auto *input = buffers[2];
    auto *weight = buffers[3];
    auto *xgrad = buffers[4];
    auto *wgrad = buffers[5];
    auto *workspace = buffers[6];
    auto *barrier = buffers[7];
    auto *dgamma_part = buffers[8];

    void *mu = nullptr;
    void *dbeta = nullptr;
    void *dbeta_part = nullptr;

    const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
    auto batch_size = desc.batch_size;
    auto hidden_size = desc.hidden_size;
    auto wkspace_size = desc.wkspace_size;
    auto barrier_size = desc.barrier_size;
    auto dgamma_part_sizes = desc.dgamma_part_sizes;
    size_t dbeta_part_sizes[2] = {0, 0};
    auto in_dtype = desc.x_dtype;
    auto w_dtype = desc.w_dtype;
    auto wkspace_dtype = desc.wkspace_dtype;
    auto barrier_dtype = desc.barrier_dtype;
    auto dgamma_part_dtype = desc.dgamma_part_dtype;
    auto dbeta_part_dtype = DType::kByte;
    auto eps = desc.eps;
    auto zero_centered_gamma = desc.zero_centered_gamma;

    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
                          dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
                          w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                          rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                          dbeta_part_dtype, stream);
}

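// FP8 quantize/dequantize. Quantize casts the input into FP8 using the given scale and records the
// new amax; Dequantize converts the FP8 input back to the requested higher-precision dtype.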
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *amax = reinterpret_cast<float *>(buffers[1]);
    auto *scale = reinterpret_cast<float *>(buffers[2]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];
    auto *amax_out = reinterpret_cast<float *>(buffers[5]);
    assert(amax == amax_out);

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
    auto shape = desc.shape.to_vector();
    auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
    auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);

    nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    auto *input = buffers[0];
    auto *amax = reinterpret_cast<float *>(buffers[1]);
    auto *scale = reinterpret_cast<float *>(buffers[2]);
    auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
    auto *output = buffers[4];

    const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

    auto shape = desc.shape.to_vector();
    auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);

    auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);

    nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}

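// Fused scaled-softmax kernels. Logits are laid out as {batch, heads, q_seqlen, k_seqlen} and the
// descriptor's scale factor is applied inside the kernel.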
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto input_tensor = TensorWrapper(input, shape, dtype);
    auto output_tensor = TensorWrapper(output, shape, dtype);

    nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor,
                                stream);
}

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    auto *grad_output = buffers[0];
    auto *softmax_output = buffers[1];
    auto *dgrad = buffers[2];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto shape = std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
    auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
    auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);

    nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
                                 dgrad_tensor.data(), desc.scale_factor, stream);
}

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                size_t opaque_len) {
    auto *input = buffers[0];
    auto *mask = buffers[1];
    auto *output = buffers[2];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto io_shape =
        std::vector<size_t>{desc.batch_size, desc.head_dim, desc.q_seqlen, desc.k_seqlen};
    auto mask_shape = std::vector<size_t>{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto input_tensor = TensorWrapper(input, io_shape, dtype);
    // The mask is cast to uint8_t
    auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
    auto output_tensor = TensorWrapper(output, io_shape, dtype);

    nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
                                       output_tensor.data(), desc.scale_factor, stream);
}

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 size_t opaque_len) {
    // The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax.
    ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           size_t opaque_len) {
    auto *input = buffers[0];
    auto *output = buffers[1];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto attn_batch = desc.batch_size * desc.head_dim;
    auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto input_tensor = TensorWrapper(input, shape, dtype);

    auto output_tensor = TensorWrapper(output, shape, dtype);

    nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
                                                    desc.scale_factor, stream);
}

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            size_t opaque_len) {
    auto *grad_output = buffers[0];
    auto *softmax_output = buffers[1];
    auto *dgrad = buffers[2];

    const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
    auto attn_batch = desc.batch_size * desc.head_dim;
    auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
    auto dtype = desc.dtype;

    auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
    auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
    auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);

    nvte_scaled_upper_triang_masked_softmax_backward(
        grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
        desc.scale_factor, stream);
}

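// Thin wrapper exposing cuDNN fused-attention backend selection to the Python side.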
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim) {
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen,
        head_dim);
    return backend;
}

/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
                                       const CustomCallFusedAttnDescriptor *desc,
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
    auto input_batch = desc->input_batch;
    auto bias_batch = desc->bias_batch;
    auto attn_heads = desc->attn_heads;
    auto bias_heads = desc->bias_heads;
    auto q_max_seqlen = desc->q_max_seqlen;
    auto kv_max_seqlen = desc->kv_max_seqlen;

    // all backends need softmax but expect different shapes/dtypes
    // start with the max512 sequence length softmax shape/dtype and correct later
    tensor_pack->size = 1;
    Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
    softmax_aux->data.dptr = softmax_buf;
    softmax_aux->data.shape =
        std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
    softmax_aux->data.dtype = desc->dtype;

    // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
        tensor_pack->size = 2;
        Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
        rng_state_aux->data.dptr = rng_state_buf;
        rng_state_aux->data.shape = std::vector<size_t>{2};
        rng_state_aux->data.dtype = DType::kInt64;
        // correct softmax shape/dtype
        softmax_aux->data.shape.at(3) = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
        softmax_aux->data.dtype = DType::kFloat32;

        // include bias if enabled
        if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
            tensor_pack->size = 3;
            Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
            bias_aux->data.dptr = bias_buf;
            bias_aux->data.shape =
                std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
            bias_aux->data.dtype = desc->dtype;
        }
    }
}

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
                                        const CustomCallFusedAttnDescriptor *desc,
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
    // Backward calls put everything into the tensor pack for every backend
    // so we set dummy bias_type and backend choices here to follow the correct code path
    auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
    auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
    PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend,
                                      softmax_buf, rng_state_buf, bias_buf);

    // correct softmax shape for max512 sequence length kernel
    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
        Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
        softmax_aux->data.shape.at(3) = desc->kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
        softmax_aux->data.dtype = desc->dtype;
    }
}

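// Query the cuDNN workspace size for the self-attention forward pass (packed QKV, BS3HD layout) by
// running the kernel with null data pointers and a query workspace tensor.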
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t max_seqlen,
    size_t attn_heads, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

    auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};

    auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
    auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
    auto cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto o_tensor = TensorWrapper(
        nullptr, std::vector<size_t>{input_batch, max_seqlen, attn_heads, head_dim}, dtype);
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
    auto rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                  rng_state_tensor.data(), max_seqlen, is_training, scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type,
                                  query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

1069
1070
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
1071
1072
1073
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

1074
    // input buffers from XLA
1075
1076
1077
    void *qkv = buffers[0];
    void *bias = buffers[1];
    void *cu_seqlens = buffers[2];
1078
    void *seed = buffers[3];
1079

1080
    // output buffers from XLA
1081
1082
    void *output = buffers[4];
    void *softmax_aux = buffers[5];
1083
    void *rng_state = buffers[6];
1084
    void *workspace = buffers[7];
1085

1086
    // tensor sizes
1087
1088
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
1089
    auto max_seqlen = descriptor.q_max_seqlen;
1090
1091
    auto attn_heads = descriptor.attn_heads;
    auto bias_heads = descriptor.bias_heads;
1092
    auto head_dim = descriptor.head_dim;
1093
1094
1095
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;
zlsh80826's avatar
zlsh80826 committed
1096

1097
    auto dtype = descriptor.dtype;
1098
1099
    auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};
1100

1101
    // input tensors
1102
1103
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
1104
    auto cu_seqlens_tensor =
1105
        TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
1106

1107
    // output tensors
1108
1109
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in FP16/BF16
    auto o_tensor = TensorWrapper(
        output, std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim}, dtype);

    // prep RNG state
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
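    // rng_state is an XLA output buffer that is also consumed by the forward kernel below;
    // PopulateRngStateAsync is expected to initialize it from the XLA-provided seed on this
    // stream, in whatever layout the selected fused-attention backend requires.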
    PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream);

    // auxiliary tensors (to be propagated to the backward pass later)
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);
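    // PrepareFusedAttnForwardAuxTensors (see utils) is expected to point the entries of this
    // pack at the softmax_aux output buffer, with shapes that depend on the selected backend
    // and bias type, so the kernel writes its auxiliary outputs straight into XLA-owned memory.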
    PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                      softmax_aux);

    // cuDNN workspace
    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
    auto wkspace_dtype = descriptor.wkspace_dtype;
    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                  rng_state_tensor.data(), max_seqlen, descriptor.is_training,
                                  descriptor.scaling_factor, dropout_probability, qkv_layout,
                                  bias_type, mask_type, workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t max_seqlen,
    size_t attn_heads, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

    auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
    auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};

    auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    auto cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);

    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                  s_tensor.data(),  // not used for F16
                                  s_tensor.data(),  // not used for F16
                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                  cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type,
                                  query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input buffers from XLA
    void *qkv = buffers[0];
    void *bias = buffers[1];
    void *softmax_aux = buffers[2];
    void *rng_state = buffers[3];
    void *output = buffers[4];
    void *doutput = buffers[5];
    void *cu_seqlens = buffers[6];

    // output buffers from XLA
    void *dqkv = buffers[7];
    void *dbias = buffers[8];
    void *workspace = buffers[9];

    // tensor sizes
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto max_seqlen = descriptor.q_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;

    auto dtype = descriptor.dtype;
    auto qkv_shape = std::vector<size_t>{input_batch * max_seqlen, 3, attn_heads, head_dim};
    auto output_shape = std::vector<size_t>{input_batch * max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, max_seqlen, max_seqlen};

    // input tensors
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

    // output tensors
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in FP16/BF16
    auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
    auto cu_seqlens_tensor =
        TensorWrapper(cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    // auxiliary tensors (propagated from the forward pass)
    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim);
    PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                       rng_state, bias);

    // cuDNN workspace
    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
    auto wkspace_dtype = descriptor.wkspace_dtype;
    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                  s_tensor.data(),  // not used for F16
                                  s_tensor.data(),  // not used for F16
                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                  cu_seqlens_tensor.data(), max_seqlen, scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type,
                                  workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_input_tensors);
}

pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
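    // With the BSHD_BS2HD layout, Q lives in its own [tokens_q, heads, head_dim] buffer while
    // K and V are packed together as [tokens_kv, 2, num_gqa_groups, head_dim]; the batch and
    // sequence dimensions are flattened into a single leading token dimension here.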

    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

    auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    // FP16/BF16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
    auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
                                 s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                                 q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                                 dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                                 is_training, scaling_factor, dropout_probability, qkv_layout,
                                 bias_type, mask_type, query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input buffers from XLA
    void *q = buffers[0];
    void *kv = buffers[1];
    void *bias = buffers[2];
    void *q_cu_seqlens = buffers[3];
    void *kv_cu_seqlens = buffers[4];
    void *seed = buffers[5];

    // output buffers from XLA
    void *output = buffers[6];
    void *softmax_aux = buffers[7];
    void *rng_state = buffers[8];
    void *workspace = buffers[9];

    // tensor sizes
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    // input tensors
    auto dtype = descriptor.dtype;
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

    // output tensors
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in FP16/BF16
    auto o_tensor = TensorWrapper(output, q_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    // prep RNG state
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

    // auxiliary tensors (to be propagated to the backward pass later)
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);
    PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                      softmax_aux);

    // cuDNN workspace
    auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                          descriptor.wkspace_dtype);

    nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
                                 s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                                 q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                                 rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                                 descriptor.is_training, scaling_factor, dropout_probability,
                                 qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
    // FP16/BF16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);

    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_bwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for FP16/BF16
        s_tensor.data(),  // not used for FP16/BF16
        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input buffers from XLA
    void *q = buffers[0];
    void *kv = buffers[1];
    void *bias = buffers[2];
    void *softmax_aux = buffers[3];
    void *rng_state = buffers[4];
    void *output = buffers[5];
    void *doutput = buffers[6];
    void *q_cu_seqlens = buffers[7];
    void *kv_cu_seqlens = buffers[8];

    // output buffers from XLA
    void *dq = buffers[9];
    void *dkv = buffers[10];
    void *dbias = buffers[11];
    void *workspace = buffers[12];

    // tensor sizes
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    // input tensors
    auto dtype = descriptor.dtype;
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

    // output tensors
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in FP16/BF16
    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    // auxiliary tensors (propagated from the forward pass)
    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                       rng_state, bias);

    // cuDNN workspace
    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
    auto wkspace_dtype = descriptor.wkspace_dtype;
    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

    nvte_fused_attn_bwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for FP16/BF16
        s_tensor.data(),  // not used for FP16/BF16
        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_input_tensors);
}

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
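    // BSHD_BSHD_BSHD keeps Q, K and V in three separate buffers, so K and V can carry
    // num_gqa_groups heads (grouped-query attention) while Q keeps attn_heads heads.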

    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

    auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
    auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);

    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                        query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input buffers from XLA
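    // (The pointer order is an assumption tied to how the corresponding primitive is lowered
    //  on the Python side: the XLA custom call passes operand buffers first and result buffers
    //  afterwards, so any reordering there has to be mirrored here.)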
    void *q = buffers[0];
    void *k = buffers[1];
    void *v = buffers[2];
    void *bias = buffers[3];
    void *q_cu_seqlens = buffers[4];
    void *kv_cu_seqlens = buffers[5];
    void *seed = buffers[6];

    // output buffers from XLA
    void *output = buffers[7];
    void *softmax_aux = buffers[8];
    void *rng_state = buffers[9];
    void *workspace = buffers[10];

    // tensor sizes
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    // input tensors
    auto dtype = descriptor.dtype;
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

    // output tensors
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
    auto o_tensor = TensorWrapper(output, q_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    // prep RNG state
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

    // auxiliary tensors (to be propagated to the backward pass later)
    NVTETensorPack aux_output_tensors;
    nvte_tensor_pack_create(&aux_output_tensors);
    PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                      softmax_aux);

    // cuDNN workspace
    auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                          descriptor.wkspace_dtype);

    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
                        descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
                        bias_type, mask_type, workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_output_tensors);
}

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
    auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
    // F16 doesn't use this tensor
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

    auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
    auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
    auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);

    TensorWrapper query_workspace_tensor;
    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                        doutput_tensor.data(),
                        s_tensor.data(),  // not used for F16
                        s_tensor.data(),  // not used for F16
                        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                        dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                        kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
                        dropout_probability, qkv_layout, bias_type, mask_type,
                        query_workspace_tensor.data(), nullptr);

    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
    const CustomCallFusedAttnDescriptor &descriptor =
        *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

    // input buffers from XLA
    void *q = buffers[0];
    void *k = buffers[1];
    void *v = buffers[2];
    void *bias = buffers[3];
    void *softmax_aux = buffers[4];
    void *rng_state = buffers[5];
    void *output = buffers[6];
    void *doutput = buffers[7];
    void *q_cu_seqlens = buffers[8];
    void *kv_cu_seqlens = buffers[9];

    // output buffers from XLA
    void *dq = buffers[10];
    void *dk = buffers[11];
    void *dv = buffers[12];
    void *dbias = buffers[13];
    void *workspace = buffers[14];

    // tensor sizes
    auto input_batch = descriptor.input_batch;
    auto bias_batch = descriptor.bias_batch;
    auto q_max_seqlen = descriptor.q_max_seqlen;
    auto kv_max_seqlen = descriptor.kv_max_seqlen;
    auto attn_heads = descriptor.attn_heads;
    auto num_gqa_groups = descriptor.num_gqa_groups;
    auto bias_heads = descriptor.bias_heads;
    auto head_dim = descriptor.head_dim;
    auto scaling_factor = descriptor.scaling_factor;
    auto dropout_probability = descriptor.dropout_probability;
    auto bias_type = descriptor.bias_type;
    auto mask_type = descriptor.mask_type;

    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};

    // input tensors
    auto dtype = descriptor.dtype;
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    auto output_tensor = TensorWrapper(output, output_shape, dtype);
    auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

    // output tensors
    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
    auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
    auto q_cu_seqlens_tensor =
        TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);

    // auxiliary tensors (propagated from the forward pass)
    NVTETensorPack aux_input_tensors;
    nvte_tensor_pack_create(&aux_input_tensors);
    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD;
    auto backend = nvte_get_fused_attn_backend(
        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
        mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
        head_dim);
    PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                       rng_state, bias);

    // cuDNN workspace
    auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
    auto wkspace_dtype = descriptor.wkspace_dtype;
    auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                        doutput_tensor.data(),
                        s_tensor.data(),  // not used for F16
                        s_tensor.data(),  // not used for F16
                        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                        dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                        kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
                        dropout_probability, qkv_layout, bias_type, mask_type,
                        workspace_tensor.data(), stream);

    nvte_tensor_pack_destroy(&aux_input_tensors);
}

}  // namespace jax
}  // namespace transformer_engine