/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "../extensions.h"
8
#include "transformer_engine/fused_attn.h"
9
#include "transformer_engine/transformer_engine.h"
10
11
12
13

namespace transformer_engine {
namespace jax {

14
// Returns the fused-attention backend that nvte_get_fused_attn_backend() reports for the given
// attention configuration.
//
// The dtypes are converted to NVTEDType before the query. The softmax type is always
// NVTE_VANILLA_SOFTMAX, and the two trailing boolean options of the query are passed as false
// (their meaning is defined by nvte_get_fused_attn_backend(), which is not visible here).
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_attn_heads, size_t kv_attn_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t qk_head_dim, size_t v_head_dim,
                                            int64_t window_size_left, int64_t window_size_right) {
  const NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  return nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
      false, false);
}

/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
36
37
38
39
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                       const size_t bias_batch, const size_t attn_heads,
                                       const size_t bias_heads, const size_t q_max_seqlen,
                                       const size_t kv_max_seqlen, DType dtype,
40
41
42
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
43
44
45
  // all backends need softmax but expect different shapes/dtypes
  // start with the max512 sequence length softmax shape/dtype and correct later
  tensor_pack->size = 1;
46
47
48
49
50
51
52
53
54
  NVTETensor &softmax_aux = tensor_pack->tensors[0];
  NVTEBasicTensor softmax_aux_data;
  softmax_aux_data.data_ptr = softmax_buf;
  softmax_aux_data.shape.ndim = 4;
  softmax_aux_data.shape.data[0] = input_batch;
  softmax_aux_data.shape.data[1] = attn_heads;
  softmax_aux_data.shape.data[2] = q_max_seqlen;
  softmax_aux_data.shape.data[3] = kv_max_seqlen;
  softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
55
56
57
58

  // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    tensor_pack->size = 2;
59
60
61
62
63
64
65
    NVTETensor &rng_state_aux = tensor_pack->tensors[1];
    NVTEBasicTensor rng_state_aux_data;
    rng_state_aux_data.data_ptr = rng_state_buf;
    rng_state_aux_data.shape = {};
    rng_state_aux_data.shape.ndim = 2;
    rng_state_aux_data.dtype = static_cast<NVTEDType>(DType::kInt64);
    nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data);
66
    // correct softmax shape/dtype
67
68
    softmax_aux_data.shape.data[3] = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
    softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
69
70
71
72

    // include bias if enabled
    if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
      tensor_pack->size = 3;
73
74
75
76
77
78
79
80
81
82
      NVTETensor &bias_aux = tensor_pack->tensors[2];
      NVTEBasicTensor bias_aux_data;
      bias_aux_data.data_ptr = bias_buf;
      bias_aux_data.shape.ndim = 4;
      bias_aux_data.shape.data[0] = bias_batch;
      bias_aux_data.shape.data[1] = bias_heads;
      bias_aux_data.shape.data[2] = q_max_seqlen;
      bias_aux_data.shape.data[3] = kv_max_seqlen;
      bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
      nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
83
    }
84
  }
85
  nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
86
87
88
89
90
91
92
93
94
95
}

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
96
97
98
99
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                        const size_t bias_batch, const size_t attn_heads,
                                        const size_t bias_heads, const size_t q_max_seqlen,
                                        const size_t kv_max_seqlen, DType dtype,
100
101
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
102
103
104
105
  // Backward calls put everything into the tensor pack for every backend
  // so we set dummy bias_type and backend choices here to follow the correct code path
  auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
  auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
106
107
108
  PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
                                    q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
                                    dummy_backend, softmax_buf, rng_state_buf, bias_buf);
109
110
111

  // correct softmax shape for max512 sequence length kernel
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
112
113
114
115
116
    NVTEBasicTensor softmax_aux_data =
        nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData);
    softmax_aux_data.shape.data[3] = kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
    softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
    nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data);
117
  }
118
119
120
121
}

// Queries the workspace that nvte_fused_attn_fwd() reports for the given configuration.
//
// All tensors are built with null data pointers, and a default-constructed workspace tensor is
// passed to nvte_fused_attn_fwd(); its shape and dtype after the call(s) are what is returned.
//
// Returns a Python tuple (workspace_shape, workspace_dtype).
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
  auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
  auto dummy_softmax_offset_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);

  TensorWrapper query_workspace_tensor;

  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  auto cudnn_runtime_version = cudnnGetVersion();
  if (is_ragged && cudnn_runtime_version >= 90300) {
    // With cuDNN >= 9.3.0 only the maximum segment count is queried. For cuDNN < 9.3.0 the
    // loop below runs every segment count from input_batch up to the maximum, which is
    // required to address act_seqlen = 0.
    min_num_segments = input_batch * max_segments_per_seq;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
        ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
        window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
  }

  nvte_tensor_pack_destroy(&aux_output_tensors);

  auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}

183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Shared prologue for the fused-attention Impl functions.
//
// Expects these names in the expanding scope: qkv_layout, input_batch, bias_batch, bias_heads,
// q_max_seqlen, kv_max_seqlen, max_segments_per_seq, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets,
// k_seq_offsets, workspace, wkspace_size, wkspace_dtype and stream.
//
// Declares: is_ragged, bias_shape, num_segments, seq_shape, q_cu_seqlens_tensor,
// kv_cu_seqlens_tensor, q_seq_offsets_tensor, k_seq_offsets_tensor, workspace_tensor and
// layout_group.
//
// For ragged (THD) layouts num_segments is input_batch * max_segments_per_seq when the cuDNN
// runtime version is >= 90300. Otherwise it is obtained from nvte_get_runtime_num_segments()
// for both Q and KV, which must agree and must not exceed input_batch * max_segments_per_seq.
#define FUSED_ATTN_IMPL_COMMON_BLOCK                                                          \
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;              \
  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \
  size_t num_segments = input_batch;                                                          \
  if (is_ragged) {                                                                            \
    auto cudnn_runtime_version = cudnnGetVersion();                                           \
    if (cudnn_runtime_version >= 90300) {                                                     \
      num_segments = input_batch * max_segments_per_seq;                                      \
    } else {                                                                                  \
      size_t runtime_num_segments_q = nvte_get_runtime_num_segments(                          \
          q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);                       \
      size_t runtime_num_segments_kv = nvte_get_runtime_num_segments(                         \
          kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);                     \
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);                          \
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);               \
      num_segments = runtime_num_segments_q;                                                  \
    }                                                                                         \
  }                                                                                           \
  std::vector<size_t> seq_shape{num_segments + 1};                                            \
  auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32);           \
  auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, seq_shape, DType::kInt32);         \
  auto q_seq_offsets_tensor = TensorWrapper(q_seq_offsets, seq_shape, DType::kInt32);         \
  auto k_seq_offsets_tensor = TensorWrapper(k_seq_offsets, seq_shape, DType::kInt32);         \
  auto workspace_tensor =                                                                     \
      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);             \
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);

210
static void FusedAttnForwardImpl(
211
212
213
214
    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens,
    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
    void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
215
216
217
218
    size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
219
  FUSED_ATTN_IMPL_COMMON_BLOCK;
220

221
222
223
  /* Input tensors */
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

224
  if (is_ragged) {
225
    auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
226
    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
Reese Wang's avatar
Reese Wang committed
227
228
229
230

    // Memset to 0xF0 for filling large negative numbers
    auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads;
    cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream);
231
232
  }

233
234
  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
235
  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
236
237
238
239
  auto o_tensor = TensorWrapper(output, o_shape, dtype);

  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
240
241
242
243
244

  auto dummy_softmax_offset_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;

245
  auto backend = nvte_get_fused_attn_backend(
246
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
247
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
248
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
249
      false, false);
250
  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
251
252
253
254

  /* Auxiliary tensors (to be propagated to the backward pass later) */
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
255
256
257
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
                                    backend, softmax_aux);
258

259
  /* Call the underlying NVTE API */
260
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
261
262
263
264
265
266
267
268
269
270

  // Prepare Q, K, V pointers and shapes based on layout
  // Python passes dummy tensors for unused slots, so we extract from the actual packed data
  void *q_ptr = q;
  void *k_ptr = k;
  void *v_ptr = v;
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};

271
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
272
273
274
275
276
277
278
279
280
281
282
283
284
    // QKV packed in q: [batch*seqlen, 3, heads, dim]
    // Python passes: q=packed_qkv, k=dummy, v=dummy
    // Extract K and V pointers from the packed q data
    NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
    NVTE_CHECK(qk_head_dim == v_head_dim,
               "For QKV packed layout, qk_head_dim must equal v_head_dim");
    size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
    q_ptr = q;
    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
    // For packed QKV, all have same shape since they're views into the same packed tensor
    k_shape = q_shape;
    v_shape = q_shape;
285
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
286
287
288
289
290
291
292
293
294
295
296
    // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
    // Python passes: q=query, k=packed_kv, v=dummy
    // Extract V pointer from the packed k data
    NVTE_CHECK(qk_head_dim == v_head_dim,
               "For KV packed layout, qk_head_dim must equal v_head_dim");
    size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
    q_ptr = q;
    k_ptr = k;
    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
    // V has same shape as K since they're packed together
    v_shape = k_shape;
297
  }
298
299
300
301
302
303
304
305
306
307
308
309
310
311
  // else NVTE_HD_HD_HD: pointers and shapes already correct

  auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
  auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
  auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);

  nvte_fused_attn_fwd(
      q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
      dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
      k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
      window_size_left, window_size_right, workspace_tensor.data(), stream);
312
313

  nvte_tensor_pack_destroy(&aux_output_tensors);
314
315
}

316
317
318
319
320
321
322
323
// Unpacks the fused-attention FFI attribute dictionary into local variables.
//
// Expects these names in the expanding scope: `attrs` (the attribute dictionary), `q_buf`
// (queried with `.element_type()`) and `workspace_buf` (queried with `->dimensions()` and
// `->element_type()`).
//
// Declares: input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
// bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, window_size_left,
// window_size_right, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout,
// is_training, deterministic, is_ragged, wkspace_size, dtype and wkspace_dtype.
//
// Integer attributes are read as int64_t and stored as size_t; scaling_factor and
// dropout_probability are read as double and stored as float.
#define FUSED_ATTN_FFI_GET_ATTRS                                                        \
  size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch");                   \
  size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch");                     \
  size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen");                 \
  size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen");               \
  size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads");                     \
  size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups");             \
  size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads");                     \
  size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim");                   \
  size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim");                     \
  size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
  auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left");           \
  auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right");         \
  float scaling_factor = get_attr_value<double>(attrs, "scaling_factor");               \
  float dropout_probability = get_attr_value<double>(attrs, "dropout_probability");     \
  NVTE_Bias_Type bias_type =                                                            \
      static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));         \
  NVTE_Mask_Type mask_type =                                                            \
      static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));         \
  NVTE_QKV_Layout qkv_layout =                                                          \
      static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));       \
  bool is_training = get_attr_value<bool>(attrs, "is_training");                        \
  bool deterministic = get_attr_value<bool>(attrs, "deterministic");                    \
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;        \
  size_t wkspace_size = product(workspace_buf->dimensions());                           \
  DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type());                 \
  DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

344
// XLA FFI entry point for the fused-attention forward pass.
//
// Decodes the attributes with FUSED_ATTN_FFI_GET_ATTRS, forwards the raw buffer pointers to
// FusedAttnForwardImpl() and returns the result of ffi_with_cuda_error_check().
//
// The q/k sequence-offset buffers are passed on only for ragged (THD) layouts; otherwise
// nullptr is passed in their place. `_unused_args` is not read.
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                               Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf,
                               Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
                               Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
                               Variadic_Buffer_Type _unused_args, Result_Type output_buf,
                               Result_Type softmax_aux_buf, Result_Type rng_state_buf,
                               Result_Type workspace_buf, Dictionary attrs) {
  FUSED_ATTN_FFI_GET_ATTRS;

  FusedAttnForwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
      bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
      kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
      softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
      qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
      dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
      deterministic, window_size_left, window_size_right);

  return ffi_with_cuda_error_check();
}

// Registers FusedAttnForwardFFI as the FusedAttnForwardHandler FFI symbol. The binding order
// below must match the parameter order of FusedAttnForwardFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // q
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // seed_buf
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
                                  .Arg<Buffer_Type>()      // q_seq_offsets
                                  .Arg<Buffer_Type>()      // k_seq_offsets
                                  .RemainingArgs()         // _cp_aux_args unused
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // softmax_aux
                                  .Ret<Buffer_Type>()      // rng_state
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs(),
                              FFI_CudaGraph_Traits);

387
388
// Queries the workspace that nvte_fused_attn_bwd() reports for the given configuration.
//
// All tensors are built with null data pointers, and a default-constructed workspace tensor is
// passed to nvte_fused_attn_bwd(); its shape and dtype after the call(s) are what is returned.
// `is_training` is accepted but not referenced in this function.
//
// Returns a Python tuple (workspace_shape, workspace_dtype).
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
  auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);

  TensorWrapper query_workspace_tensor;

  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  auto cudnn_runtime_version = cudnnGetVersion();
  if (is_ragged && cudnn_runtime_version >= 90300) {
    // With cuDNN >= 9.3.0 only the maximum segment count is queried. For cuDNN < 9.3.0 the
    // loop below runs every segment count from input_batch up to the maximum, which is
    // required to address act_seqlen = 0.
    min_num_segments = input_batch * max_segments_per_seq;
  }
  auto dummy_d_softmax_offset_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);

    nvte_fused_attn_bwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
        doutput_tensor.data(),
        s_tensor.data(),  // not used for F16
        s_tensor.data(),  // not used for F16
        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
        dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
        dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
        window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
  }

  nvte_tensor_pack_destroy(&aux_input_tensors);

  auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

459
460
461
462
463
static void FusedAttnBackwardImpl(
    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state,
    void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
    void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
464
465
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
    size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
466
467
468
469
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;
470
471

  /* Input tensors */
472
  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
473
474
475
476
477
478
  auto output_tensor = TensorWrapper(output, output_shape, dtype);
  auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
479
480
481
  auto dummy_d_softmax_offset_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
  NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
482
483
484
485

  /* Auxiliary tensors (propagated from the forward pass) */
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
486
  auto backend = nvte_get_fused_attn_backend(
487
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
488
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
489
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
490
      false, false);
491
492
493
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                     softmax_aux, rng_state, bias);
494
495

  /* Call the underly NVTE API */
496
497
498
499
500
501
502
503
504
505
506
  // Prepare Q, K, V pointers and shapes based on layout
  void *q_ptr = q;
  void *k_ptr = k;
  void *v_ptr = v;
  void *dq_ptr = dq;
  void *dk_ptr = dk;
  void *dv_ptr = dv;
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};

507
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
508
509
510
511
512
513
514
515
516
517
518
519
520
    // QKV packed in q: [batch*seqlen, 3, heads, dim]
    NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
    NVTE_CHECK(qk_head_dim == v_head_dim,
               "For QKV packed layout, qk_head_dim must equal v_head_dim");
    size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
    q_ptr = q;
    k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
    v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
    dq_ptr = dq;
    dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
    k_shape = q_shape;
    v_shape = q_shape;
521
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
    // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
    NVTE_CHECK(qk_head_dim == v_head_dim,
               "For KV packed layout, qk_head_dim must equal v_head_dim");
    size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
    q_ptr = q;
    k_ptr = k;
    v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
    dq_ptr = dq;
    dk_ptr = dk;
    dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
    // V has same shape as K since they're packed together
    v_shape = k_shape;
  }

  auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
  auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
  auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
  auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
  auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
  auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);

  if (is_ragged) {
    size_t dtype_size = typeToSize(dtype);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      // For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once
      cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      // Clear dq
      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
      // For packed KV, dk contains both dk and dv - clear all at once
      cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream);
    } else {
      // All separate - clear each individually
      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
      cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream);
      cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream);
558
    }
559
560
  }

561
562
563
564
565
566
567
568
569
570
571
  nvte_fused_attn_bwd(
      q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
      doutput_tensor.data(),
      s_tensor.data(),  // not used for F16
      s_tensor.data(),  // not used for F16
      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
      dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
      window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);

572
  nvte_tensor_pack_destroy(&aux_input_tensors);
573
574
}

575
576
577
578
579
580
/*
    XLA FFI entry point for the fused attention backward pass.

    Unpacks the raw data pointers from the FFI buffers, forwards them together with the decoded
    attributes to FusedAttnBackwardImpl, and returns the result of ffi_with_cuda_error_check().

    `_unused_args` receives any trailing input buffers and is intentionally ignored.

    NOTE(review): is_ragged, the size/shape values, dtypes and the other scalar arguments passed
    below are not declared in this function; they are presumably extracted from `attrs` by
    FUSED_ATTN_FFI_GET_ATTRS - confirm against the macro definition.
*/
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                                Buffer_Type v_buf, Buffer_Type bias_buf,
                                Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
                                Buffer_Type output_buf, Buffer_Type doutput_buf,
                                Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
                                Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
                                Variadic_Buffer_Type _unused_args, Result_Type dq_buf,
                                Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf,
                                Result_Type workspace_buf, Dictionary attrs) {
  FUSED_ATTN_FFI_GET_ATTRS;

  // The sequence-offset buffers are only forwarded for ragged layouts; otherwise the
  // implementation receives nullptr for both.
  FusedAttnBackwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
      bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(),
      output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
      kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
      dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
      workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
      attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq,
      wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype,
      wkspace_dtype, is_training, deterministic, window_size_left, window_size_right);

  return ffi_with_cuda_error_check();
}

// Defines the FusedAttnBackwardHandler symbol for FusedAttnBackwardFFI. The binding order below
// mirrors the parameter order of FusedAttnBackwardFFI: stream context, twelve input buffers,
// the trailing variadic inputs, five result buffers, then the attribute dictionary.
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // q
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // softmax_aux
                                  .Arg<Buffer_Type>()      // rng_state
                                  .Arg<Buffer_Type>()      // output
                                  .Arg<Buffer_Type>()      // doutput
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
                                  .Arg<Buffer_Type>()      // q_seq_offsets
                                  .Arg<Buffer_Type>()      // k_seq_offsets
                                  .RemainingArgs()         // _cp_aux_args, ignored (_unused_args)
                                  .Ret<Buffer_Type>()      // dq
                                  .Ret<Buffer_Type>()      // dk
                                  .Ret<Buffer_Type>()      // dv
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs(),
                              FFI_CudaGraph_Traits);

625
626
}  // namespace jax
}  // namespace transformer_engine