/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"
#include "transformer_engine/fused_attn.h"

namespace transformer_engine {
namespace jax {

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_attn_heads, size_t kv_attn_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim) {
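  // Note (assumption): head_dim is passed for both the QK and V head dimensions, and the trailing
  // -1, -1 arguments presumably leave sliding-window attention disabled (no window limits).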
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, -1, -1);
  return backend;
}

/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack,
                                       const CustomCallFusedAttnDescriptor *desc,
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
  auto input_batch = desc->input_batch;
  auto bias_batch = desc->bias_batch;
  auto attn_heads = desc->attn_heads;
  auto bias_heads = desc->bias_heads;
  auto q_max_seqlen = desc->q_max_seqlen;
  auto kv_max_seqlen = desc->kv_max_seqlen;

  // all backends need softmax but expect different shapes/dtypes
  // start with the max512 sequence length softmax shape/dtype and correct later
  tensor_pack->size = 1;
  Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
  softmax_aux->data.dptr = softmax_buf;
  softmax_aux->data.shape =
      std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
  softmax_aux->data.dtype = desc->dtype;

  // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    tensor_pack->size = 2;
    Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
    rng_state_aux->data.dptr = rng_state_buf;
    rng_state_aux->data.shape = std::vector<size_t>{2};
    rng_state_aux->data.dtype = DType::kInt64;
    // correct softmax shape/dtype
    softmax_aux->data.shape.at(3) = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
    softmax_aux->data.dtype = DType::kFloat32;

    // include bias if enabled
    if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
      tensor_pack->size = 3;
      Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
      bias_aux->data.dptr = bias_buf;
      bias_aux->data.shape =
          std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
      bias_aux->data.dtype = desc->dtype;
    }
  }
}

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack, and the nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack,
                                        const CustomCallFusedAttnDescriptor *desc,
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
  // Backward calls put everything into the tensor pack for every backend
  // so we set dummy bias_type and backend choices here to follow the correct code path
  auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
  auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
  PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, softmax_buf,
                                    rng_state_buf, bias_buf);

  // correct softmax shape for max512 sequence length kernel
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
    Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
    softmax_aux->data.shape.at(3) = desc->kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
    softmax_aux->data.dtype = desc->dtype;
  }
}

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq) {
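  // Workspace-size query: the forward kernels below are invoked with null data pointers and a
  // null stream, so (in the usual TE size-query pattern) they only populate
  // query_workspace_tensor's shape and dtype rather than executing; the result is returned to
  // XLA so the workspace buffer can be allocated before the actual FusedAttnForward call.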
  // For qkv_packed
  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

  // For kv_packed
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

  // For separate q, k, v
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = k_shape;
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
  auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);

  TensorWrapper query_workspace_tensor;
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // This is a workaround (WAR) to pre-create all possible cuDNN graphs at JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  auto cudnn_runtime_version = cudnnGetVersion();
  if (is_ragged && cudnn_runtime_version >= 90300) {
    // For cuDNN < 9.3.0, all possible seqlens must be run to handle act_seqlen = 0
    min_num_segments = input_batch * max_segments_per_seq;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last iteration is the largest and determines the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
      nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                    o_tensor.data(), &aux_output_tensors,
                                    q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
                                    dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
                                    scaling_factor, dropout_probability, qkv_layout, bias_type,
                                    mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_fwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
          &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
          q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
          bias_type, mask_type, -1, -1, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                          s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                          ragged_offset_tensor.data(), ragged_offset_tensor.data(),
                          dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
                          -1, query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported QKVLayout.");
    }
  }

  auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  auto qkv_layout = descriptor.qkv_layout;
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
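  // THD format indicates ragged input: variable-length segments packed along the token axis,
  // with segment boundaries supplied through the cu_seqlens and seq_offsets buffers below.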

  /* Input buffers from XLA */
  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
  void *bias = buffers[3];
  void *q_cu_seqlens = buffers[4];
  void *kv_cu_seqlens = buffers[5];
  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
  void *seed = buffers[8];

  /* Output buffers from XLA */
  void *output = buffers[9];
  void *softmax_aux = buffers[10];
  void *rng_state = buffers[11];
  void *workspace = buffers[12];

  /* Descriptor */
  auto input_batch = descriptor.input_batch;
  auto bias_batch = descriptor.bias_batch;
  auto q_max_seqlen = descriptor.q_max_seqlen;
  auto kv_max_seqlen = descriptor.kv_max_seqlen;
  auto attn_heads = descriptor.attn_heads;
  auto num_gqa_groups = descriptor.num_gqa_groups;
  auto bias_heads = descriptor.bias_heads;
  auto head_dim = descriptor.head_dim;
  auto scaling_factor = descriptor.scaling_factor;
  auto dropout_probability = descriptor.dropout_probability;
  auto bias_type = descriptor.bias_type;
  auto mask_type = descriptor.mask_type;
  auto dtype = descriptor.dtype;
  auto is_training = descriptor.is_training;
  auto max_segments_per_seq = descriptor.max_segments_per_seq;

  /* Input tensors */
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto v_shape = k_shape;
  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

  size_t num_segments = input_batch;  // Non-THD format, input_batch = num_segments
  if (is_ragged) {
    auto cudnn_runtime_version = cudnnGetVersion();
    if (cudnn_runtime_version >= 90300) {
      num_segments = input_batch * max_segments_per_seq;
    } else {
      // the workspace buffer can be reused here since the cuDNN graph does not use it at the same time
      size_t runtime_num_segments_q =
          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
      size_t runtime_num_segments_kv =
          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
      num_segments = runtime_num_segments_q;
    }
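    // The kernels for ragged (THD) input only write the valid packed tokens, so the output
    // buffer is zero-filled up front, presumably to avoid stale values in the padded regions.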
    cudaMemsetAsync(output, 0,
                    input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
  }

  auto q_cu_seqlens_tensor =
      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto q_seq_offsets_tensor =
      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto k_seq_offsets_tensor =
      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);

  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto o_tensor = TensorWrapper(output, o_shape, dtype);

  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, -1, -1);
  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

  /* Auxiliary tensors (to be propagated to the backward pass later) */
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend,
                                    softmax_aux);

  /* cuDNN workspace */
  auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                        descriptor.wkspace_dtype);

  /* Call the underlying NVTE API */
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv = buffers[0];
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
                                  q_max_seqlen, is_training, descriptor.scaling_factor,
                                  dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
                                  workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q = buffers[0];
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv = buffers[1];
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, -1, -1, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q = buffers[0];
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k = buffers[1];
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v = buffers[2];
    auto v_shape = k_shape;
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
                        -1, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }

  nvte_tensor_pack_destroy(&aux_output_tensors);
}

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq) {
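  // Same workspace-size query pattern as GetFusedAttnForwardWorkspaceSizes above, but for the
  // backward kernels: null data pointers and a null stream, returning only the workspace shape.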
  // For qkv_packed
  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
  auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

  // For kv_packed
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
  auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

  // For separate q, k, v
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = k_shape;
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);

  TensorWrapper query_workspace_tensor;

  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // This is a workaround (WAR) to pre-create all possible cuDNN graphs at JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  auto cudnn_runtime_version = cudnnGetVersion();
  if (is_ragged && cudnn_runtime_version >= 90300) {
    // For cuDNN < 9.3.0, all possible seqlens must be run to handle act_seqlen = 0
    min_num_segments = input_batch * max_segments_per_seq;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last iteration is the largest and determines the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      nvte_fused_attn_bwd_qkvpacked(
          qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
          dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
          qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_bwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
          kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
          -1, true, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                          doutput_tensor.data(),
                          s_tensor.data(),  // not used for F16
                          s_tensor.data(),  // not used for F16
                          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                          dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
                          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
                          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
                          -1, true, query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported qkv_layout.");
    }
  }

  auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  auto qkv_layout = descriptor.qkv_layout;
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
  void *bias = buffers[3];
  void *softmax_aux = buffers[4];
  void *rng_state = buffers[5];
  void *output = buffers[6];
  void *doutput = buffers[7];
  void *q_cu_seqlens = buffers[8];
  void *kv_cu_seqlens = buffers[9];
  void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;

  /* Output buffers from XLA */
  /* Buffers[12-14] are dq, dk, dv, which are parsed later for different qkv_layout */
  void *dbias = buffers[15];
  void *workspace = buffers[16];

  /* Descriptor */
  auto input_batch = descriptor.input_batch;
  auto bias_batch = descriptor.bias_batch;
  auto q_max_seqlen = descriptor.q_max_seqlen;
  auto kv_max_seqlen = descriptor.kv_max_seqlen;
  auto attn_heads = descriptor.attn_heads;
  auto num_gqa_groups = descriptor.num_gqa_groups;
  auto bias_heads = descriptor.bias_heads;
  auto head_dim = descriptor.head_dim;
  auto scaling_factor = descriptor.scaling_factor;
  auto dropout_probability = descriptor.dropout_probability;
  auto bias_type = descriptor.bias_type;
  auto mask_type = descriptor.mask_type;
  auto dtype = descriptor.dtype;
  auto max_segments_per_seq = descriptor.max_segments_per_seq;

  /* Input tensors */
  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto output_tensor = TensorWrapper(output, output_shape, dtype);
  auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

  size_t num_segments = input_batch;  // Non-THD format, input_batch = num_segments
  if (is_ragged) {
    auto cudnn_runtime_version = cudnnGetVersion();
    if (cudnn_runtime_version >= 90300) {
      num_segments = input_batch * max_segments_per_seq;
    } else {
      // the workspace buffer can be reused here since the cuDNN graph does not use it at the same time
      size_t runtime_num_segments_q =
          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
      size_t runtime_num_segments_kv =
          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
      num_segments = runtime_num_segments_q;
    }
  }

  auto q_cu_seqlens_tensor =
      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto kv_cu_seqlens_tensor =
      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto q_seq_offsets_tensor =
      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
  auto k_seq_offsets_tensor =
      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);

  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);

  /* Auxiliary tensors (propagated from the forward pass) */
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, -1, -1);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux,
                                     rng_state, bias);

  /* cuDNN workspace */
  auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
  auto wkspace_dtype = descriptor.wkspace_dtype;
  auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

  /* Call the underlying NVTE API */
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv = buffers[0];
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
    auto dqkv = buffers[12];
    auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
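    // As in the forward pass, ragged inputs require the gradient buffers to be zero-filled
    // first, since the kernels only write the valid packed tokens.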
    if (is_ragged) {
      size_t dqkv_size =
          std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
      cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
    }
    nvte_fused_attn_bwd_qkvpacked(
        qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for F16
        s_tensor.data(),  // not used for F16
        &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q = buffers[0];
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv = buffers[1];
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
    auto dq = buffers[12];
    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dkv = buffers[13];
    auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
    if (is_ragged) {
      size_t dq_size =
          std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
      size_t dkv_size =
          std::accumulate(kv_shape.cbegin(), kv_shape.cend(), 1, std::multiplies<size_t>());
      cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
      cudaMemsetAsync(dkv, 0, dkv_size * typeToSize(dtype), stream);
    }
    nvte_fused_attn_bwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for F16
        s_tensor.data(),  // not used for F16
        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
        k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
        dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
        workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q = buffers[0];
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k = buffers[1];
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v = buffers[2];
    auto v_shape = k_shape;
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    auto dq = buffers[12];
    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dk = buffers[13];
    auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
    auto dv = buffers[14];
    auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
    if (is_ragged) {
      size_t dq_size =
          std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
      size_t dk_size =
          std::accumulate(k_shape.cbegin(), k_shape.cend(), 1, std::multiplies<size_t>());
      size_t dv_size = dk_size;
      cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
      cudaMemsetAsync(dk, 0, dk_size * typeToSize(dtype), stream);
      cudaMemsetAsync(dv, 0, dv_size * typeToSize(dtype), stream);
    }
    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                        doutput_tensor.data(),
                        s_tensor.data(),  // not used for F16
                        s_tensor.data(),  // not used for F16
                        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                        dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
                        k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
                        dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
                        workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }

  nvte_tensor_pack_destroy(&aux_input_tensors);
}

}  // namespace jax
}  // namespace transformer_engine