attention.cpp 39.5 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "../extensions.h"
8
#include "transformer_engine/fused_attn.h"
9
#include "transformer_engine/transformer_engine.h"
10
11
12
13
14
15
16
17
18

namespace transformer_engine {
namespace jax {

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_attn_heads, size_t kv_attn_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right) {
  // Thin adapter over nvte_get_fused_attn_backend(): the JAX-side DType values are converted to
  // NVTEDType and every other argument is forwarded unchanged. The single `head_dim` is passed
  // for both head-dimension arguments of the C API (presumably the QK and V head dims).
  const auto q_nvte_dtype = static_cast<NVTEDType>(q_dtype);
  const auto kv_nvte_dtype = static_cast<NVTEDType>(kv_dtype);
  return nvte_get_fused_attn_backend(q_nvte_dtype, kv_nvte_dtype, qkv_layout, bias_type, mask_type,
                                     dropout_probability, q_attn_heads, kv_attn_heads,
                                     q_max_seqlen, kv_max_seqlen, head_dim, head_dim,
                                     window_size_left, window_size_right);
}

/*
    NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused
    attention forward kernels in:
        - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812
        - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359
*/
34
35
36
37
void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                       const size_t bias_batch, const size_t attn_heads,
                                       const size_t bias_heads, const size_t q_max_seqlen,
                                       const size_t kv_max_seqlen, DType dtype,
38
39
40
                                       NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
                                       void *softmax_buf, void *rng_state_buf = nullptr,
                                       void *bias_buf = nullptr) {
41
42
43
  // all backends need softmax but expect different shapes/dtypes
  // start with the max512 sequence length softmax shape/dtype and correct later
  tensor_pack->size = 1;
44
45
46
47
48
49
50
51
52
  NVTETensor &softmax_aux = tensor_pack->tensors[0];
  NVTEBasicTensor softmax_aux_data;
  softmax_aux_data.data_ptr = softmax_buf;
  softmax_aux_data.shape.ndim = 4;
  softmax_aux_data.shape.data[0] = input_batch;
  softmax_aux_data.shape.data[1] = attn_heads;
  softmax_aux_data.shape.data[2] = q_max_seqlen;
  softmax_aux_data.shape.data[3] = kv_max_seqlen;
  softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
53
54
55
56

  // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    tensor_pack->size = 2;
57
58
59
60
61
62
63
    NVTETensor &rng_state_aux = tensor_pack->tensors[1];
    NVTEBasicTensor rng_state_aux_data;
    rng_state_aux_data.data_ptr = rng_state_buf;
    rng_state_aux_data.shape = {};
    rng_state_aux_data.shape.ndim = 2;
    rng_state_aux_data.dtype = static_cast<NVTEDType>(DType::kInt64);
    nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data);
64
    // correct softmax shape/dtype
65
66
    softmax_aux_data.shape.data[3] = 1;  // {B,H,Qs,Ks} -> {B,H,Qs,1}
    softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
67
68
69
70

    // include bias if enabled
    if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
      tensor_pack->size = 3;
71
72
73
74
75
76
77
78
79
80
      NVTETensor &bias_aux = tensor_pack->tensors[2];
      NVTEBasicTensor bias_aux_data;
      bias_aux_data.data_ptr = bias_buf;
      bias_aux_data.shape.ndim = 4;
      bias_aux_data.shape.data[0] = bias_batch;
      bias_aux_data.shape.data[1] = bias_heads;
      bias_aux_data.shape.data[2] = q_max_seqlen;
      bias_aux_data.shape.data[3] = kv_max_seqlen;
      bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
      nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
81
    }
82
  }
83
  nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
84
85
86
87
88
89
90
91
92
93
}

/*
    NOTE: Backward fused attention kernels accept auxiliary tensors as explicit function arguments
    instead of an NVTETensorPack and nvte_fused_attn_bwd() API does all the logic for pulling the
    necessary tensors out of the tensor pack for the active kernel. That means we can just dump
    everything we got into the tensor pack and not worry about its sizing for the backward pass.

    TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()?
*/
94
95
96
97
void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                        const size_t bias_batch, const size_t attn_heads,
                                        const size_t bias_heads, const size_t q_max_seqlen,
                                        const size_t kv_max_seqlen, DType dtype,
98
99
                                        NVTE_Fused_Attn_Backend backend, void *softmax_buf,
                                        void *rng_state_buf, void *bias_buf) {
100
101
102
103
  // Backward calls put everything into the tensor pack for every backend
  // so we set dummy bias_type and backend choices here to follow the correct code path
  auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS;
  auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
104
105
106
  PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads,
                                    q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type,
                                    dummy_backend, softmax_buf, rng_state_buf, bias_buf);
107
108
109

  // correct softmax shape for max512 sequence length kernel
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
110
111
112
113
114
    NVTEBasicTensor softmax_aux_data =
        nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData);
    softmax_aux_data.shape.data[3] = kv_max_seqlen;  // {B,H,Qs,1} -> {B,H,Qs,Ks}
    softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
    nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data);
115
  }
116
117
118
119
120
121
}

/*
 * Query the workspace required by the fused-attention forward pass and return it to Python as a
 * (shape, dtype) tuple.
 *
 * All tensors below wrap null data pointers and the NVTE calls receive a null stream; the calls
 * are used only to fill in `query_workspace_tensor`. For ragged (THD) layouts on cuDNN < 9.3.0
 * the query is repeated for every segment count in [min_num_segments, max_num_segments], which
 * also pre-creates the corresponding cuDNN graphs at JIT-compile time (a workaround). The result
 * of the last (largest) query is returned.
 */
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    // Loop-invariant: check once instead of on every query below.
    NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
  }

  // For qkv_packed
  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

  // For kv_packed
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

  // For separate q, k, v
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = k_shape;
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
  auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);

  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
  // Destroy the pack on every exit path. NVTE_ERROR / NVTE_CHECK (here or inside the NVTE calls)
  // throw, which would skip a plain nvte_tensor_pack_destroy() at the end of the function.
  struct TensorPackGuard {
    NVTETensorPack *pack;
    ~TensorPackGuard() { nvte_tensor_pack_destroy(pack); }
  } aux_output_guard{&aux_output_tensors};

  TensorWrapper query_workspace_tensor;

  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  if (is_ragged && cudnnGetVersion() >= 90300) {
    // Only cuDNN < 9.3.0 needs every possible segment count queried (to address
    // act_seqlen = 0); newer versions only need the largest one.
    min_num_segments = max_num_segments;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      nvte_fused_attn_fwd_qkvpacked(
          qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
          &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
          dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
          window_size_right, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_fwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
          &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
          dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
          kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
          mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_fwd(
          q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
          o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
          kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
          dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
          dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
          dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
          window_size_right, query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported QKVLayout.");
    }
  }

  auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}

208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
/*
 * Shared prologue of the fused-attention Impl functions. It expands to local declarations that
 * read the enclosing function's parameters (qkv_layout, input_batch, bias_batch, bias_heads,
 * q_max_seqlen, kv_max_seqlen, max_segments_per_seq, q_cu_seqlens, kv_cu_seqlens,
 * q_seq_offsets, k_seq_offsets, workspace, wkspace_size, wkspace_dtype, stream) and defines
 * is_ragged, bias_shape, num_segments, the four {num_segments + 1} int32 sequence tensors,
 * workspace_tensor and layout_group.
 *
 * For ragged (THD) layouts on cuDNN >= 9.3.0 the segment count is the static upper bound
 * input_batch * max_segments_per_seq. Older cuDNN versions use the actual count obtained from
 * nvte_get_runtime_num_segments(), which is also handed `workspace` (presumably as scratch).
 */
#define FUSED_ATTN_IMPL_COMMON_BLOCK                                                          \
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;              \
  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \
  size_t num_segments = input_batch;                                                          \
  if (is_ragged) {                                                                            \
    auto cudnn_runtime_version = cudnnGetVersion();                                           \
    if (cudnn_runtime_version >= 90300) {                                                     \
      num_segments = input_batch * max_segments_per_seq;                                      \
    } else {                                                                                  \
      size_t runtime_num_segments_q = nvte_get_runtime_num_segments(                          \
          q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);                       \
      size_t runtime_num_segments_kv = nvte_get_runtime_num_segments(                         \
          kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);                     \
      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv,                           \
                 "Q and KV must have the same number of segments.");                          \
      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq,                \
                 "Number of segments exceeds input_batch * max_segments_per_seq.");           \
      num_segments = runtime_num_segments_q;                                                  \
    }                                                                                         \
  }                                                                                           \
  std::vector<size_t> seq_shape{num_segments + 1};                                            \
  auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32);           \
  auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, seq_shape, DType::kInt32);         \
  auto q_seq_offsets_tensor = TensorWrapper(q_seq_offsets, seq_shape, DType::kInt32);         \
  auto k_seq_offsets_tensor = TensorWrapper(k_seq_offsets, seq_shape, DType::kInt32);         \
  auto workspace_tensor =                                                                     \
      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);             \
  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);

235
static void FusedAttnForwardImpl(
236
237
238
239
240
241
242
243
    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens,
    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
    void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
    size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
244
  FUSED_ATTN_IMPL_COMMON_BLOCK;
245

246
247
248
  /* Input tensors */
  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

249
  if (is_ragged) {
250
251
    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
Reese Wang's avatar
Reese Wang committed
252
253
254
255

    // Memset to 0xF0 for filling large negative numbers
    auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads;
    cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream);
256
257
  }

258
259
260
261
262
263
264
  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto o_tensor = TensorWrapper(output, o_shape, dtype);

  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
265
266
267
  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
268
      head_dim, head_dim, window_size_left, window_size_right);
269
  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
270
271
272
273

  /* Auxiliary tensors (to be propagated to the backward pass later) */
  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
274
275
276
  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
                                    backend, softmax_aux);
277

278
  /* Call the underlying NVTE API */
279
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
280
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
281
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
282
283
284
285
286
287
288
    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
                                  qkv_layout, bias_type, mask_type, window_size_left,
                                  window_size_right, workspace_tensor.data(), stream);
289
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
290
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
291
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
292
293
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
294
295
296
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
297
298
299
300
        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
        dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
        is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        window_size_left, window_size_right, workspace_tensor.data(), stream);
301
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
302
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
303
304
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
305
306
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
307
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
308
309
310
311
312
313
314
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
315
316
317
318
319
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }

  nvte_tensor_pack_destroy(&aux_output_tensors);
320
321
}

322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
/*
 * Declares, in the enclosing FFI handler, one local per fused-attention attribute read from
 * `attrs`, plus derived values:
 *   is_ragged     - whether qkv_layout uses the THD format,
 *   wkspace_size  - product of workspace_buf's dimensions,
 *   dtype         - TE dtype of q_buf's elements,
 *   wkspace_dtype - TE dtype of workspace_buf's elements.
 * `attrs`, `q_buf` and `workspace_buf` must be in scope where the macro is expanded.
 * Only block comments are used inside the body so the line continuations stay intact.
 */
#define FUSED_ATTN_FFI_GET_ATTRS                                                                  \
  /* Shape attributes (int64 in the attribute dictionary) */                                      \
  auto input_batch = static_cast<size_t>(get_attr_value<int64_t>(attrs, "input_batch"));          \
  auto bias_batch = static_cast<size_t>(get_attr_value<int64_t>(attrs, "bias_batch"));            \
  auto q_max_seqlen = static_cast<size_t>(get_attr_value<int64_t>(attrs, "q_max_seqlen"));        \
  auto kv_max_seqlen = static_cast<size_t>(get_attr_value<int64_t>(attrs, "kv_max_seqlen"));      \
  auto attn_heads = static_cast<size_t>(get_attr_value<int64_t>(attrs, "attn_heads"));            \
  auto num_gqa_groups = static_cast<size_t>(get_attr_value<int64_t>(attrs, "num_gqa_groups"));    \
  auto bias_heads = static_cast<size_t>(get_attr_value<int64_t>(attrs, "bias_heads"));            \
  auto head_dim = static_cast<size_t>(get_attr_value<int64_t>(attrs, "head_dim"));                \
  auto max_segments_per_seq =                                                                     \
      static_cast<size_t>(get_attr_value<int64_t>(attrs, "max_segments_per_seq"));                \
  /* Sliding-window bounds */                                                                     \
  auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left");                     \
  auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right");                   \
  /* Scalars (stored as double, narrowed to float) */                                             \
  auto scaling_factor = static_cast<float>(get_attr_value<double>(attrs, "scaling_factor"));      \
  auto dropout_probability =                                                                      \
      static_cast<float>(get_attr_value<double>(attrs, "dropout_probability"));                   \
  /* Enum attributes (stored as int64) */                                                         \
  auto bias_type = static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));      \
  auto mask_type = static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));      \
  auto qkv_layout = static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));   \
  /* Flags */                                                                                     \
  bool is_training = get_attr_value<bool>(attrs, "is_training");                                  \
  bool deterministic = get_attr_value<bool>(attrs, "deterministic");                              \
  /* Derived values */                                                                            \
  bool is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;                  \
  size_t wkspace_size = product(workspace_buf->dimensions());                                     \
  DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type());                           \
  DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

349
/*
 * XLA FFI entry point for the fused-attention forward pass.
 *
 * Reads the attributes via FUSED_ATTN_FFI_GET_ATTRS and forwards the raw device pointers to
 * FusedAttnForwardImpl. The sequence-offset buffers are forwarded only for ragged (THD) layouts;
 * otherwise nullptr is passed. `_unused_args` is ignored. CUDA errors are reported through
 * ffi_with_cuda_error_check().
 */
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                               Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf,
                               Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
                               Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
                               Variadic_Buffer_Type _unused_args, Result_Type output_buf,
                               Result_Type softmax_aux_buf, Result_Type rng_state_buf,
                               Result_Type workspace_buf, Dictionary attrs) {
  FUSED_ATTN_FFI_GET_ATTRS;

  // Offsets are only meaningful for the ragged layout.
  void *q_offsets_ptr = nullptr;
  void *k_offsets_ptr = nullptr;
  if (is_ragged) {
    q_offsets_ptr = q_seq_offsets_buf.untyped_data();
    k_offsets_ptr = k_seq_offsets_buf.untyped_data();
  }

  FusedAttnForwardImpl(stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
                       bias_buf.untyped_data(), seed_buf.untyped_data(),
                       q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
                       q_offsets_ptr, k_offsets_ptr, output_buf->untyped_data(),
                       softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
                       workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen,
                       kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, head_dim,
                       max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability,
                       bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
                       deterministic, window_size_left, window_size_right);

  return ffi_with_cuda_error_check();
}

// Registers FusedAttnForwardFFI as the XLA FFI handler symbol FusedAttnForwardHandler. The
// binding order mirrors FusedAttnForwardFFI's parameter list:
//   - the stream context;
//   - nine input buffers;
//   - the remaining (unused) variadic arguments;
//   - four result buffers;
//   - the attribute dictionary.
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // cudaStream_t stream
                                  .Arg<Buffer_Type>()      // q_buf
                                  .Arg<Buffer_Type>()      // k_buf
                                  .Arg<Buffer_Type>()      // v_buf
                                  .Arg<Buffer_Type>()      // bias_buf
                                  .Arg<Buffer_Type>()      // seed_buf
                                  .Arg<Buffer_Type>()      // q_cu_seqlens_buf
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens_buf
                                  .Arg<Buffer_Type>()      // q_seq_offsets_buf
                                  .Arg<Buffer_Type>()      // k_seq_offsets_buf
                                  .RemainingArgs()         // _unused_args (cp aux args)
                                  .Ret<Buffer_Type>()      // output_buf
                                  .Ret<Buffer_Type>()      // softmax_aux_buf
                                  .Ret<Buffer_Type>()      // rng_state_buf
                                  .Ret<Buffer_Type>()      // workspace_buf
                                  .Attrs(),                // attrs
                              FFI_CudaGraph_Traits);

392
393
394
395
/*
 * Query the workspace required by the fused-attention backward pass and return it to Python as
 * a (shape, dtype) tuple.
 *
 * All tensors below wrap null data pointers and the NVTE calls receive a null stream; the calls
 * are used only to fill in `query_workspace_tensor`. For ragged (THD) layouts on cuDNN < 9.3.0
 * the query is repeated for every segment count in [min_num_segments, max_num_segments], which
 * also pre-creates the corresponding cuDNN graphs at JIT-compile time (a workaround). The result
 * of the last (largest) query is returned. `is_training` is accepted but not used here.
 */
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right) {
  // For qkv_packed
  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
  auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

  // For kv_packed
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
  auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

  // For separate q, k, v
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
  auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
  auto v_shape = k_shape;
  auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);

  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
  auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);

  // F16 doesn't use this tensor
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
  auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);

  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  // Destroy the pack on every exit path. NVTE_ERROR below, or an error raised inside the NVTE
  // calls, would otherwise skip a plain nvte_tensor_pack_destroy() at the end of the function.
  struct TensorPackGuard {
    NVTETensorPack *pack;
    ~TensorPackGuard() { nvte_tensor_pack_destroy(pack); }
  } aux_input_guard{&aux_input_tensors};

  TensorWrapper query_workspace_tensor;

  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
  // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
  size_t min_num_segments = input_batch;
  if (is_ragged && cudnnGetVersion() >= 90300) {
    // Only cuDNN < 9.3.0 needs every possible segment count queried (to address
    // act_seqlen = 0); newer versions only need the largest one.
    min_num_segments = max_num_segments;
  }
  for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
    // the last one is the largest which will be the returned workspace size
    auto q_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto kv_cu_seqlens_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
      nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                    s_tensor.data(),  // not used for F16
                                    s_tensor.data(),  // not used for F16
                                    &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                    q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
                                    q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
                                    bias_type, mask_type, window_size_left, window_size_right,
                                    deterministic, query_workspace_tensor.data(), nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
      nvte_fused_attn_bwd_kvpacked(
          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
          s_tensor.data(),  // not used for F16
          s_tensor.data(),  // not used for F16
          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
          dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
          kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
          window_size_left, window_size_right, deterministic, query_workspace_tensor.data(),
          nullptr);
    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
      nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                          doutput_tensor.data(),
                          s_tensor.data(),  // not used for F16
                          s_tensor.data(),  // not used for F16
                          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                          dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
                          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
                          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
                          window_size_left, window_size_right, deterministic,
                          query_workspace_tensor.data(), nullptr);
    } else {
      NVTE_ERROR("Unsupported qkv_layout.");
    }
  }

  auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
  return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}

496
497
498
499
500
501
502
503
504
505
506
/*
    FusedAttnBackwardImpl enqueues the F16 fused-attention backward pass on `stream`.

    The NVTE entry point is selected by `layout_group`:
        - NVTE_3HD:      packed QKV in `q`; the packed gradient is written to `dq`
                         (nvte_fused_attn_bwd_qkvpacked).
        - NVTE_HD_2HD:   Q in `q`, packed KV in `k`; gradients are written to `dq` and `dk`
                         (nvte_fused_attn_bwd_kvpacked).
        - NVTE_HD_HD_HD: separate Q/K/V in `q`/`k`/`v`; gradients are written to `dq`/`dk`/`dv`
                         (nvte_fused_attn_bwd).
    Any other layout group raises NVTE_ERROR("Unsupported qkv_layout.").

    The S/dP tensors are passed as a single-element placeholder (not used for F16).
    The auxiliary tensors from the forward pass (softmax_aux, rng_state, bias) are packed by
    PrepareFusedAttnBackwardAuxTensors using the backend chosen for this problem shape.

    When `is_ragged` is set, the gradient buffers of the selected branch are zeroed with
    cudaMemsetAsync before the kernel is enqueued. Presumably this is because the kernel does
    not write every position for ragged inputs. TODO confirm.

    NOTE(review): `layout_group`, `is_ragged`, `bias_shape`, `q_cu_seqlens_tensor`,
    `kv_cu_seqlens_tensor`, `q_seq_offsets_tensor`, `k_seq_offsets_tensor` and `workspace_tensor`
    are not declared in this function. They presumably come from FUSED_ATTN_IMPL_COMMON_BLOCK.
*/
static void FusedAttnBackwardImpl(
    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state,
    void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
    void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;

  /* Input tensors */
  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
  auto output_tensor = TensorWrapper(output, output_shape, dtype);
  auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);

  /* Output tensors */
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);

  /* Auxiliary tensors (propagated from the forward pass) */
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  // Destroy the pack on every exit path. This includes exceptions thrown by
  // PrepareFusedAttnBackwardAuxTensors or by NVTE_ERROR in the unsupported-layout branch,
  // so the pack does not leak when the function unwinds.
  struct AuxTensorPackGuard {
    NVTETensorPack *pack;
    ~AuxTensorPackGuard() { nvte_tensor_pack_destroy(pack); }
  } aux_input_tensors_guard{&aux_input_tensors};

  // Same query as GetFusedAttnBackend: Q and KV share `dtype` and `head_dim`.
  auto backend = GetFusedAttnBackend(dtype, dtype, qkv_layout, bias_type, mask_type,
                                     dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
                                     kv_max_seqlen, head_dim, window_size_left, window_size_right);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                     softmax_aux, rng_state, bias);

  /* Call the underlying NVTE API */
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
    auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
    if (is_ragged) {
      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype),
                      stream);
    }
    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                  s_tensor.data(),  // not used for F16
                                  s_tensor.data(),  // not used for F16
                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
                                  q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
                                  q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
                                  bias_type, mask_type, window_size_left, window_size_right,
                                  deterministic, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype);
    if (is_ragged) {
      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
      cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype),
                      stream);
    }
    nvte_fused_attn_bwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
        s_tensor.data(),  // not used for F16
        s_tensor.data(),  // not used for F16
        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
        k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
        dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
        deterministic, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
    auto v_shape = k_shape;
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
    auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
    auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
    if (is_ragged) {
      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
      cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
      cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
    }
    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                        doutput_tensor.data(),
                        s_tensor.data(),  // not used for F16
                        s_tensor.data(),  // not used for F16
                        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                        dbias_tensor.data(), q_cu_seqlens_tensor.data(),
                        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
                        k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
                        dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
                        window_size_right, deterministic, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }
  // aux_input_tensors is released by aux_input_tensors_guard when this scope exits.
}

598
599
600
601
602
603
/*
    FusedAttnBackwardFFI is the XLA FFI entry point for the fused-attention backward pass.

    It forwards the raw device pointers of every argument and result buffer to
    FusedAttnBackwardImpl. It then reports any pending CUDA error through
    ffi_with_cuda_error_check().

    The sequence-offset buffers are only forwarded when `is_ragged` is set; otherwise nullptr is
    passed in their place. `_unused_args` (the variadic remainder) is not read.

    NOTE(review): the scalar locals used below (input_batch, ..., is_ragged) are not declared here.
    They presumably come from FUSED_ATTN_FFI_GET_ATTRS, which reads them from `attrs`.
*/
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                                Buffer_Type v_buf, Buffer_Type bias_buf,
                                Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
                                Buffer_Type output_buf, Buffer_Type doutput_buf,
                                Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
                                Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
                                Variadic_Buffer_Type _unused_args, Result_Type dq_buf,
                                Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf,
                                Result_Type workspace_buf, Dictionary attrs) {
  FUSED_ATTN_FFI_GET_ATTRS;

  // Sequence offsets are only meaningful for ragged inputs; pass nullptr otherwise.
  void *q_seq_offsets_ptr = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
  void *k_seq_offsets_ptr = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;

  FusedAttnBackwardImpl(
      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
      bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(),
      output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
      kv_cu_seqlens_buf.untyped_data(), q_seq_offsets_ptr, k_seq_offsets_ptr,
      dq_buf->untyped_data(), dk_buf->untyped_data(), dv_buf->untyped_data(),
      dbias_buf->untyped_data(), workspace_buf->untyped_data(), input_batch, bias_batch,
      q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, head_dim,
      max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
      mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
      window_size_right);

  return ffi_with_cuda_error_check();
}

// Registers FusedAttnBackwardFFI as the XLA FFI handler symbol FusedAttnBackwardHandler.
// The Ctx/Arg/RemainingArgs/Ret order below must match the parameter order of
// FusedAttnBackwardFFI; attributes are passed through as a Dictionary.
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // q
                                  .Arg<Buffer_Type>()      // k
                                  .Arg<Buffer_Type>()      // v
                                  .Arg<Buffer_Type>()      // bias
                                  .Arg<Buffer_Type>()      // softmax_aux
                                  .Arg<Buffer_Type>()      // rng_state
                                  .Arg<Buffer_Type>()      // output
                                  .Arg<Buffer_Type>()      // doutput
                                  .Arg<Buffer_Type>()      // q_cu_seqlens
                                  .Arg<Buffer_Type>()      // kv_cu_seqlens
                                  .Arg<Buffer_Type>()      // q_seq_offsets
                                  .Arg<Buffer_Type>()      // k_seq_offsets
                                  .RemainingArgs()  // _cp_aux_args, bound to `_unused_args` (unused)
                                  .Ret<Buffer_Type>()      // dq
                                  .Ret<Buffer_Type>()      // dk
                                  .Ret<Buffer_Type>()      // dv
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attrs(),
                              FFI_CudaGraph_Traits);

648
649
}  // namespace jax
}  // namespace transformer_engine