/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"
8
9
10
11
12
13
14
15
16
17
#include "transformer_engine/fused_attn.h"

namespace transformer_engine {
namespace jax {

/*
    Thin wrapper around nvte_get_fused_attn_backend().

    The two DType arguments are cast to NVTEDType; every other argument is forwarded unchanged.
    head_dim is supplied for both head-dimension arguments of the NVTE call.
    NOTE(review): presumably those are the QK and V head dims -- confirm against fused_attn.h.

    Returns the backend value reported by nvte_get_fused_attn_backend().
*/
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_attn_heads, size_t kv_attn_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right) {
  return nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
}

/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
33
34
35
36
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                       const size_t bias_batch, const size_t attn_heads,
                                       const size_t bias_heads, const size_t q_max_seqlen,
                                       const size_t kv_max_seqlen, DType dtype,
37
38
39
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
40
41
42
43
44
45
46
  // all backends need softmax but expect different shapes/dtypes
  // start with the max512 sequence length softmax shape/dtype and correct later
  tensor_pack->size = 1;
  Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
  softmax_aux->data.dptr = softmax_buf;
  softmax_aux->data.shape =
      std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
47
  softmax_aux->data.dtype = dtype;
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

  // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    tensor_pack->size = 2;
    Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
    rng_state_aux->data.dptr = rng_state_buf;
    rng_state_aux->data.shape = std::vector<size_t>{2};
    rng_state_aux->data.dtype = DType::kInt64;
    // correct softmax shape/dtype
    softmax_aux->data.shape.at(3) = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
    softmax_aux->data.dtype = DType::kFloat32;

    // include bias if enabled
    if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
      tensor_pack->size = 3;
      Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
      bias_aux->data.dptr = bias_buf;
      bias_aux->data.shape =
          std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
67
      bias_aux->data.dtype = dtype;
68
    }
69
  }
70
71
72
73
74
75
76
77
78
79
}

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
80
81
82
83
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                        const size_t bias_batch, const size_t attn_heads,
                                        const size_t bias_heads, const size_t q_max_seqlen,
                                        const size_t kv_max_seqlen, DType dtype,
84
85
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
86
87
88
89
  // Backward calls put everything into the tensor pack for every backend
  // so we set dummy bias_type and backend choices here to follow the correct code path
  auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
  auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
90
91
92
  PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
                                    q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
                                    dummy_backend, softmax_buf, rng_state_buf, bias_buf);
93
94
95
96

  // correct softmax shape for max512 sequence length kernel
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
    Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
97
98
    softmax_aux->data.shape.at(3) = kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
    softmax_aux->data.dtype = dtype;
99
  }
100
101
102
103
104
105
}

/*
    Queries the workspace shape/dtype needed by the fused attention forward call.

    All tensors are built with null data pointers and the NVTE forward entry point matching the
    QKV layout group is invoked with an empty workspace tensor and a null stream; the call is
    expected to fill in query_workspace_tensor's shape and dtype.
    NOTE(review): relies on the NVTE API treating an empty workspace as a size query -- confirm
    against fused_attn.h.

    For ragged (THD) layouts the call is repeated for a range of segment counts (see below) so
    that the required cuDNN graphs exist ahead of time.

    Returns a tuple (workspace_shape, workspace_dtype) taken from the last call made.
    Raises via NVTE_CHECK / NVTE_ERROR if q_max_seqlen != kv_max_seqlen for the 3HD layout group,
    or if the layout group is not one of 3HD / HD_2HD / HD_HD_HD.
*/
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
  // For qkv_packed
  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

  // For kv_packed
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

  // For separate q, k, v
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = k_shape;
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
  auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);

  TensorWrapper query_workspace_tensor;

  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  // By default every segment count in [input_batch, max_num_segments] is run: cuDNN < 9.3.0
  // requires running all possible seqlens to address act_seqlen = 0.
  size_t min_num_segments = input_batch;
  auto cudnn_runtime_version = cudnnGetVersion();
  if (is_ragged && cudnn_runtime_version >= 90300) {
    // cuDNN >= 9.3.0: a single run at the maximum segment count is enough
    min_num_segments = input_batch * max_segments_per_seq;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
      nvte_fused_attn_fwd_qkvpacked(
          qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
          &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
          dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
          window_size_right, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_fwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
          &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
          q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
          bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(),
          nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_fwd(
          q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
          o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
          kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
          dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
          window_size_right, query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported QKVLayout.");
    }
  }

  // Release the pack created above; it is only needed for the duration of the query calls.
  nvte_tensor_pack_destroy(&aux_output_tensors);

  auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}

/*
    XLA custom-call entry point for the fused attention forward pass.

    stream      CUDA stream on which all asynchronous work is enqueued.
    buffers     XLA buffer table. Indices read here:
                  [0..2] q / k / v (interpretation depends on the QKV layout group)
                  [3] bias, [4] q_cu_seqlens, [5] kv_cu_seqlens,
                  [6] q_seq_offsets, [7] k_seq_offsets (only read for THD layouts),
                  [8] seed,
                  [9] output, [10] softmax_aux, [11] rng_state, [12] workspace
    opaque      Serialized CustomCallFusedAttnDescriptor of length opaque_len.

    Raises via NVTE_CHECK / NVTE_ERROR on inconsistent runtime segment counts (THD layouts with
    cuDNN < 9.3.0) or an unsupported QKV layout group.
*/
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  auto qkv_layout = descriptor.qkv_layout;
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
  void *bias = buffers[3];
  void *q_cu_seqlens = buffers[4];
  void *kv_cu_seqlens = buffers[5];
  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
  void *seed = buffers[8];

  /* Output buffer from XLA */
  void *output = buffers[9];
  void *softmax_aux = buffers[10];
  void *rng_state = buffers[11];
  void *workspace = buffers[12];

  /* Descriptor */
  auto input_batch = descriptor.input_batch;
  auto bias_batch = descriptor.bias_batch;
  auto q_max_seqlen = descriptor.q_max_seqlen;
  auto kv_max_seqlen = descriptor.kv_max_seqlen;
  auto attn_heads = descriptor.attn_heads;
  auto num_gqa_groups = descriptor.num_gqa_groups;
  auto bias_heads = descriptor.bias_heads;
  auto head_dim = descriptor.head_dim;
  auto scaling_factor = descriptor.scaling_factor;
  auto dropout_probability = descriptor.dropout_probability;
  auto bias_type = descriptor.bias_type;
  auto mask_type = descriptor.mask_type;
  auto dtype = descriptor.dtype;
  auto is_training = descriptor.is_training;
  auto max_segments_per_seq = descriptor.max_segments_per_seq;
  auto window_size_left = descriptor.window_size_left;
  auto window_size_right = descriptor.window_size_right;

  /* Input tensors (q/k/v wrappers are built per layout group further below) */
  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

  size_t num_segments = input_batch;  // Non-THD format, input_batch = num_segments
  if (is_ragged) {
    auto cudnn_runtime_version = cudnnGetVersion();
    if (cudnn_runtime_version >= 90300) {
      num_segments = input_batch * max_segments_per_seq;
    } else {
      // workspace can be reused here as it is not used with cuDNN graph at the same time
      size_t runtime_num_segments_q =
          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
      size_t runtime_num_segments_kv =
          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
      num_segments = runtime_num_segments_q;
    }
    // THD layouts: zero the whole output buffer before the attention call
    cudaMemsetAsync(output, 0,
                    input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
  }

  auto q_cu_seqlens_tensor =
      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto q_seq_offsets_tensor =
      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto k_seq_offsets_tensor =
      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);

  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto o_tensor = TensorWrapper(output, o_shape, dtype);

  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

  /* Auxiliary tensors (to be propagated to the backward pass later) */
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
                                    backend, softmax_aux);

  /* cuDNN workspace */
  auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                        descriptor.wkspace_dtype);

  /* Call the underlying NVTE API */
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv = buffers[0];
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    nvte_fused_attn_fwd_qkvpacked(
        qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
        rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, dropout_probability,
        qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
        workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q = buffers[0];
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv = buffers[1];
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q = buffers[0];
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k = buffers[1];
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v = buffers[2];
    auto v_shape = k_shape;
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                        window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }

  nvte_tensor_pack_destroy(&aux_output_tensors);
}

336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
/*
    XLA FFI entry point for the fused attention forward pass.

    Buffers:
      q_buf / k_buf / v_buf     Interpretation depends on the QKV layout group:
                                  3HD      -> q_buf holds packed qkv
                                  HD_2HD   -> q_buf holds q, k_buf holds packed kv
                                  HD_HD_HD -> q_buf, k_buf, v_buf hold q, k, v
      q_seq_offsets_buf / k_seq_offsets_buf are only read for THD (ragged) layouts.
      output_buf, softmax_aux_buf, rng_state_buf, workspace_buf are result buffers.

    Attributes arrive as int64_t / double and are cast to the size_t / float / enum types used by
    the NVTE API. `deterministic` is accepted but not used by this function.

    Returns the result of ffi_with_cuda_error_check().
    Raises via NVTE_CHECK / NVTE_ERROR on inconsistent runtime segment counts (THD layouts with
    cuDNN < 9.3.0) or an unsupported QKV layout group.
*/
Error_Type FusedAttnForwardFFI(
    cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
    Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
    Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, Buffer_Type seed_buf,
    Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf,
    Result_Type workspace_buf, int64_t input_batch_, int64_t bias_batch_, int64_t q_max_seqlen_,
    int64_t kv_max_seqlen_, int64_t attn_heads_, int64_t num_gqa_groups_, int64_t bias_heads_,
    int64_t head_dim_, int64_t max_segments_per_seq_, int64_t wkspace_size_, double scaling_factor_,
    double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
    int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
    int64_t window_size_left, int64_t window_size_right) {
  /* Descriptor data type conversion */
  size_t input_batch = static_cast<size_t>(input_batch_);
  size_t bias_batch = static_cast<size_t>(bias_batch_);
  size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_);
  size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_);
  size_t attn_heads = static_cast<size_t>(attn_heads_);
  size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_);
  size_t bias_heads = static_cast<size_t>(bias_heads_);
  size_t head_dim = static_cast<size_t>(head_dim_);
  size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_);
  size_t wkspace_size = static_cast<size_t>(wkspace_size_);
  float scaling_factor = static_cast<float>(scaling_factor_);
  float dropout_probability = static_cast<float>(dropout_probability_);
  NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_);
  NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_);
  NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
  DType dtype = static_cast<DType>(dtype_);
  DType wkspace_dtype = static_cast<DType>(wkspace_dtype_);
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  /* q, k, v are parsed later for different qkv_layout */
  void *bias = bias_buf.untyped_data();
  void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data();
  void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data();
  void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
  void *k_seq_offsets = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;
  void *seed = seed_buf.untyped_data();

  /* Output buffer from XLA */
  void *output = output_buf->untyped_data();
  void *softmax_aux = softmax_aux_buf->untyped_data();
  void *rng_state = rng_state_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  /* Input tensors (q/k/v wrappers are built per layout group further below) */
  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

  size_t num_segments = input_batch;  // Non-THD format, input_batch = num_segments
  if (is_ragged) {
    auto cudnn_runtime_version = cudnnGetVersion();
    if (cudnn_runtime_version >= 90300) {
      num_segments = input_batch * max_segments_per_seq;
    } else {
      // workspace can be reused here as it is not used with cuDNN graph at the same time
      size_t runtime_num_segments_q =
          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
      size_t runtime_num_segments_kv =
          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
      num_segments = runtime_num_segments_q;
    }
    // THD layouts: zero the whole output buffer before the attention call
    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
  }

  auto q_cu_seqlens_tensor =
      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto q_seq_offsets_tensor =
      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto k_seq_offsets_tensor =
      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);

  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto o_tensor = TensorWrapper(output, o_shape, dtype);

  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

  /* Auxiliary tensors (to be propagated to the backward pass later) */
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
                                    backend, softmax_aux);

  /* cuDNN workspace */
  auto workspace_tensor =
      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);

  /* Call the underlying NVTE API */
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv = q_buf.untyped_data();
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
                                  qkv_layout, bias_type, mask_type, window_size_left,
                                  window_size_right, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q = q_buf.untyped_data();
    auto kv = k_buf.untyped_data();
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q = q_buf.untyped_data();
    auto k = k_buf.untyped_data();
    auto v = v_buf.untyped_data();
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                        window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }

  nvte_tensor_pack_destroy(&aux_output_tensors);

  return ffi_with_cuda_error_check();
}

// Defines the XLA FFI handler symbol FusedAttnForwardHandler, bound to FusedAttnForwardFFI.
// The binding lists, in order: the stream context, nine input buffers, four result buffers,
// and the named attributes. This order mirrors the parameter order of FusedAttnForwardFFI,
// so any change to that signature must be reflected here (and vice versa).
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // q
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
                                  .Arg<Buffer_Type>()      // q_seq_offsets
                                  .Arg<Buffer_Type>()      // k_seq_offsets
                                  .Arg<Buffer_Type>()      // seed_buf
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // softmax_aux
                                  .Ret<Buffer_Type>()      // rng_state
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("input_batch")
                                  .Attr<int64_t>("bias_batch")
                                  .Attr<int64_t>("q_max_seqlen")
                                  .Attr<int64_t>("kv_max_seqlen")
                                  .Attr<int64_t>("attn_heads")
                                  .Attr<int64_t>("num_gqa_groups")
                                  .Attr<int64_t>("bias_heads")
                                  .Attr<int64_t>("head_dim")
                                  .Attr<int64_t>("max_segments_per_seq")
                                  .Attr<int64_t>("wkspace_size")
                                  .Attr<double>("scaling_factor")
                                  .Attr<double>("dropout_probability")
                                  .Attr<int64_t>("bias_type")
                                  .Attr<int64_t>("mask_type")
                                  .Attr<int64_t>("qkv_layout")
                                  .Attr<int64_t>("dtype")
                                  .Attr<int64_t>("wkspace_dtype")
                                  .Attr<bool>("is_training")
                                  .Attr<bool>("deterministic")
                                  .Attr<int64_t>("window_size_left")
                                  .Attr<int64_t>("window_size_right"),
                              FFI_CudaGraph_Traits);

531
532
533
534
/*
    Queries the workspace shape and dtype required by the fused attention backward kernels.

    Every tensor is built with a null data pointer and the NVTE backward entry point matching the
    layout group of `qkv_layout` is invoked with an empty workspace tensor and a null stream; the
    shape and dtype left in that workspace tensor afterwards are returned to Python.

    For ragged (THD) layouts the call is repeated for each candidate segment count so that the
    cuDNN graphs are pre-created at JIT compile time; the last (largest) count determines the
    returned workspace size.

    Returns a tuple (workspace_shape, workspace_dtype).
    Raises via NVTE_ERROR if the layout group is not one of 3HD, HD_2HD or HD_HD_HD.
    NOTE: `is_training` is accepted for interface compatibility but is not read here.
*/
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right) {
  // For qkv_packed
  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
  auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

  // For kv_packed
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
  auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

  // For separate q, k, v
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = k_shape;
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  // Created here and released after the query loop below.
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);

  // Left empty on purpose: the NVTE calls below report the required workspace through it.
  TensorWrapper query_workspace_tensor;

  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  auto cudnn_runtime_version = cudnnGetVersion();
  if (is_ragged && cudnn_runtime_version >= 90300) {
    // cuDNN >= 9.3.0 only needs the maximum segment count. Older versions must run every
    // possible count (to address act_seqlen = 0), so for them the loop starts at input_batch.
    min_num_segments = input_batch * max_segments_per_seq;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    // One placeholder stands in for both the q and kv sequence-offset arguments.
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                    s_tensor.data(),  // not used for F16
                                    s_tensor.data(),  // not used for F16
                                    &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                    q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
                                    q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
                                    bias_type, mask_type, window_size_left, window_size_right,
                                    deterministic, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_bwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
          kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
          window_size_left, window_size_right, deterministic, query_workspace_tensor.data(),
          nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                          doutput_tensor.data(),
                          s_tensor.data(),  // not used for F16
                          s_tensor.data(),  // not used for F16
                          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                          dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
                          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
                          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                          window_size_left, window_size_right, deterministic,
                          query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported qkv_layout.");
    }
  }

  // Release the auxiliary tensor pack, mirroring FusedAttnBackward; without this every
  // workspace-size query leaked the pack created above.
  nvte_tensor_pack_destroy(&aux_input_tensors);

  auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

/*
    XLA custom-call entry point for the fused attention backward pass.

    `opaque` / `opaque_len` carry a serialized CustomCallFusedAttnDescriptor holding the problem
    configuration. `buffers` is indexed as follows:
        [0-2]   q, k, v inputs (interpretation depends on the qkv layout group)
        [3]     bias             [4]  softmax_aux       [5]  rng_state
        [6]     output           [7]  doutput
        [8]     q_cu_seqlens     [9]  kv_cu_seqlens
        [10-11] q / k sequence offsets (only read for the ragged THD format)
        [12-14] dq, dk, dv outputs (interpretation depends on the qkv layout group)
        [15]    dbias            [16] workspace

    All work is issued on `stream`. An unsupported layout group raises through NVTE_ERROR.
*/
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &descriptor = *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  // Problem configuration carried by the descriptor.
  const auto qkv_layout = descriptor.qkv_layout;
  const auto input_batch = descriptor.input_batch;
  const auto bias_batch = descriptor.bias_batch;
  const auto q_max_seqlen = descriptor.q_max_seqlen;
  const auto kv_max_seqlen = descriptor.kv_max_seqlen;
  const auto attn_heads = descriptor.attn_heads;
  const auto num_gqa_groups = descriptor.num_gqa_groups;
  const auto bias_heads = descriptor.bias_heads;
  const auto head_dim = descriptor.head_dim;
  const auto scaling_factor = descriptor.scaling_factor;
  const auto dropout_probability = descriptor.dropout_probability;
  const auto bias_type = descriptor.bias_type;
  const auto mask_type = descriptor.mask_type;
  const auto dtype = descriptor.dtype;
  const auto deterministic = descriptor.deterministic;
  const auto max_segments_per_seq = descriptor.max_segments_per_seq;
  const auto window_size_left = descriptor.window_size_left;
  const auto window_size_right = descriptor.window_size_right;

  const bool is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  // Fixed-position XLA buffers; slots 0-2 and 12-14 are resolved per layout group further down.
  void *const bias = buffers[3];
  void *const softmax_aux = buffers[4];
  void *const rng_state = buffers[5];
  void *const output = buffers[6];
  void *const doutput = buffers[7];
  void *const q_cu_seqlens = buffers[8];
  void *const kv_cu_seqlens = buffers[9];
  // The offset buffers are only consumed for the ragged format.
  void *const q_seq_offsets = is_ragged ? buffers[10] : nullptr;
  void *const k_seq_offsets = is_ragged ? buffers[11] : nullptr;
  void *const dbias = buffers[15];
  void *const workspace = buffers[16];

  // Forward output and its incoming gradient share one shape.
  const std::vector<size_t> out_shape{input_batch * q_max_seqlen, attn_heads, head_dim};
  const std::vector<size_t> dbias_shape{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  TensorWrapper output_tensor(output, out_shape, dtype);
  TensorWrapper doutput_tensor(doutput, out_shape, dtype);

  // Segment count: equals the batch size unless the layout is ragged.
  size_t num_segments = input_batch;
  if (is_ragged) {
    if (cudnnGetVersion() >= 90300) {
      num_segments = input_batch * max_segments_per_seq;
    } else {
      // Older cuDNN: count the segments at runtime. The workspace buffer doubles as scratch
      // here because the cuDNN graph is not using it at this point.
      size_t runtime_num_segments_q =
          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
      size_t runtime_num_segments_kv =
          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
      num_segments = runtime_num_segments_q;
    }
  }

  // Cumulative sequence lengths and offsets all hold num_segments + 1 int32 entries.
  const std::vector<size_t> seq_shape{num_segments + 1};
  TensorWrapper q_cu_seqlens_tensor(q_cu_seqlens, seq_shape, DType::kInt32);
  TensorWrapper kv_cu_seqlens_tensor(kv_cu_seqlens, seq_shape, DType::kInt32);
  TensorWrapper q_seq_offsets_tensor(q_seq_offsets, seq_shape, DType::kInt32);
  TensorWrapper k_seq_offsets_tensor(k_seq_offsets, seq_shape, DType::kInt32);

  // Placeholder passed in the two argument slots that are not used for F16.
  TensorWrapper s_tensor(nullptr, std::vector<size_t>{1}, dtype);
  TensorWrapper dbias_tensor(dbias, dbias_shape, dtype);

  // Auxiliary tensors handed over from the forward pass.
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  const auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                     softmax_aux, rng_state, bias);

  // cuDNN workspace, sized and typed as recorded in the descriptor.
  const std::vector<size_t> workspace_shape{descriptor.wkspace_size};
  TensorWrapper workspace_tensor(workspace, workspace_shape, descriptor.wkspace_dtype);

  // Gradient buffers are zero-filled up front for the ragged format only.
  const auto zero_grad_if_ragged = [&](void *grad, const std::vector<size_t> &shape) {
    if (is_ragged) {
      cudaMemsetAsync(grad, 0, product(shape) * typeToSize(dtype), stream);
    }
  };

  // Dispatch to the NVTE backward API that matches the layout group.
  switch (nvte_get_qkv_layout_group(qkv_layout)) {
    case NVTE_QKV_Layout_Group::NVTE_3HD: {
      // q, k and v packed into one buffer; a single packed gradient comes back.
      const std::vector<size_t> packed_shape{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
      void *const dqkv = buffers[12];
      TensorWrapper qkv_tensor(buffers[0], packed_shape, dtype);
      TensorWrapper dqkv_tensor(dqkv, packed_shape, dtype);
      zero_grad_if_ragged(dqkv, packed_shape);
      nvte_fused_attn_bwd_qkvpacked(
          qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
          q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
          qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic,
          workspace_tensor.data(), stream);
      break;
    }
    case NVTE_QKV_Layout_Group::NVTE_HD_2HD: {
      // Separate q, with k and v packed together.
      const std::vector<size_t> query_shape{input_batch * q_max_seqlen, attn_heads, head_dim};
      const std::vector<size_t> packed_kv_shape{input_batch * kv_max_seqlen, 2, num_gqa_groups,
                                                head_dim};
      void *const dq = buffers[12];
      void *const dkv = buffers[13];
      TensorWrapper q_tensor(buffers[0], query_shape, dtype);
      TensorWrapper kv_tensor(buffers[1], packed_kv_shape, dtype);
      TensorWrapper dq_tensor(dq, query_shape, dtype);
      TensorWrapper dkv_tensor(dkv, packed_kv_shape, dtype);
      zero_grad_if_ragged(dq, query_shape);
      zero_grad_if_ragged(dkv, packed_kv_shape);
      nvte_fused_attn_bwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
          k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
          window_size_right, deterministic, workspace_tensor.data(), stream);
      break;
    }
    case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: {
      // Fully separate q, k and v; k and v share one shape.
      const std::vector<size_t> query_shape{input_batch * q_max_seqlen, attn_heads, head_dim};
      const std::vector<size_t> key_value_shape{input_batch * kv_max_seqlen, num_gqa_groups,
                                                head_dim};
      void *const dq = buffers[12];
      void *const dk = buffers[13];
      void *const dv = buffers[14];
      TensorWrapper q_tensor(buffers[0], query_shape, dtype);
      TensorWrapper k_tensor(buffers[1], key_value_shape, dtype);
      TensorWrapper v_tensor(buffers[2], key_value_shape, dtype);
      TensorWrapper dq_tensor(dq, query_shape, dtype);
      TensorWrapper dk_tensor(dk, key_value_shape, dtype);
      TensorWrapper dv_tensor(dv, key_value_shape, dtype);
      zero_grad_if_ragged(dq, query_shape);
      zero_grad_if_ragged(dk, key_value_shape);
      zero_grad_if_ragged(dv, key_value_shape);
      nvte_fused_attn_bwd(
          q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
          doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
          dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
          window_size_right, deterministic, workspace_tensor.data(), stream);
      break;
    }
    default:
      NVTE_ERROR("Unsupported qkv_layout.");
  }

  nvte_tensor_pack_destroy(&aux_input_tensors);
}

}  // namespace jax
}  // namespace transformer_engine