/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../common.h"
8
#include "transformer_engine/fused_attn.h"
cyanguwa's avatar
cyanguwa committed
9
10
11
12
13
14
15
16
#include "utils.h"

namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

// get matrix strides based on matrix type
17
18
19
20
21
22
void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                           int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix) {
  constexpr int batch_dim_idx = 0;
  constexpr int head_dim_idx = 1;
  constexpr int seqlen_dim_idx = 2;
  constexpr int hidden_dim_idx = 3;
cyanguwa's avatar
cyanguwa committed
23

24
25
  constexpr int seqlen_transpose_dim_idx = 3;
  constexpr int hidden_transpose_dim_idx = 2;
cyanguwa's avatar
cyanguwa committed
26

27
28
  constexpr int seqlen_q_dim_idx = 2;
  constexpr int seqlen_kv_dim_idx = 3;
cyanguwa's avatar
cyanguwa committed
29

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
  switch (layout) {
    case NVTE_QKV_Layout::NVTE_SB3HD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = 3 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * 3 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = 3 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = b * 3 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_SBH3D:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = 3 * h * d;
        strideA[head_dim_idx] = 3 * d;
        strideA[seqlen_dim_idx] = b * 3 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = 3 * h * d;
        strideA[head_dim_idx] = 3 * d;
        strideA[seqlen_transpose_dim_idx] = b * 3 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = 2 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * 2 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = 2 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = b * 2 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = 2 * h * d;
        strideA[head_dim_idx] = 2 * d;
        strideA[seqlen_dim_idx] = b * 2 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = 2 * h * d;
        strideA[head_dim_idx] = 2 * d;
        strideA[seqlen_transpose_dim_idx] = b * 2 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = b * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_BS3HD:
    case NVTE_QKV_Layout::NVTE_T3HD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = s_q * 3 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = 3 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = s_q * 3 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = 3 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
        strideA[batch_dim_idx] = s_q * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_BSH3D:
    case NVTE_QKV_Layout::NVTE_TH3D:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = s_q * 3 * h * d;
        strideA[head_dim_idx] = 3 * d;
        strideA[seqlen_dim_idx] = 3 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = s_q * 3 * h * d;
        strideA[head_dim_idx] = 3 * d;
        strideA[seqlen_transpose_dim_idx] = 3 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix) {
        strideA[batch_dim_idx] = s_q * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
    case NVTE_QKV_Layout::NVTE_THD_T2HD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = s_kv * 2 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = 2 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = s_kv * 2 * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = 2 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = s_q * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
    case NVTE_QKV_Layout::NVTE_THD_TH2D:
      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = s_kv * 2 * h * d;
        strideA[head_dim_idx] = 2 * d;
        strideA[seqlen_dim_idx] = 2 * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = s_kv * 2 * h * d;
        strideA[head_dim_idx] = 2 * d;
        strideA[seqlen_transpose_dim_idx] = 2 * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = s_q * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
    case NVTE_QKV_Layout::NVTE_THD_THD_THD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = s_q * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = s_kv * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = s_kv * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      }
      break;
  }
243

244
245
246
247
248
249
  if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
    strideA[seqlen_kv_dim_idx] = 1;
    strideA[seqlen_q_dim_idx] = s_kv;
    strideA[head_dim_idx] = s_q * s_kv;
    strideA[batch_dim_idx] = h * s_q * s_kv;
  }
cyanguwa's avatar
cyanguwa committed
250
251
}

252
253
254
255
256
// Predicate over a cuDNN engine config that unconditionally answers false.
// The argument is intentionally ignored.
// NOTE(review): presumably passed as an engine-config filter callback where
// `false` means "do not filter out" -- confirm against the call sites.
bool allowAllConfig(cudnnBackendDescriptor_t engine_config) {
  static_cast<void>(engine_config);  // silence unused-parameter warnings
  return false;
}

257
258
// Build a 4-D cudnn_frontend tensor descriptor.
//
//   type       : element data type of the tensor
//   id         : identifier assigned to the tensor via setId()
//   dim        : pointer to 4 dimension extents
//   stride     : pointer to 4 strides
//   is_virtual : forwarded to setVirtual()
//   is_value   : forwarded to setByValue()
//
// Returns the descriptor produced by cudnn_frontend::TensorBuilder::build().
cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim,
                                     int64_t const *stride, bool is_virtual, bool is_value) {
  // Every tensor built here is rank 4; dim and stride are read as 4-element arrays.
  constexpr int nbDims = 4;
  auto tensor_created =
      cudnn_frontend::TensorBuilder()
          .setDim(nbDims, dim)
          .setStride(nbDims, stride)
          .setId(id)
          .setAlignment(16)  // 16B alignment is needed to run a tensor core engine
          .setDataType(type)
          .setVirtual(is_virtual)
          .setByValue(is_value)
          .build();
  return tensor_created;
}

// Build a 4-D cudnn_frontend tensor descriptor that carries a ragged-offset tensor.
//
// Same builder settings as a plain 4-D tensor (dims, strides, id, 16B alignment,
// data type, virtual / by-value flags), plus setRaggedOffset(raggedOffset).
//
//   type         : element data type of the tensor
//   id           : identifier assigned to the tensor via setId()
//   dim          : pointer to 4 dimension extents
//   stride       : pointer to 4 strides
//   is_virtual   : forwarded to setVirtual()
//   is_value     : forwarded to setByValue()
//   raggedOffset : tensor descriptor forwarded to setRaggedOffset()
//
// Returns the descriptor produced by cudnn_frontend::TensorBuilder::build().
cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual,
    bool is_value, std::shared_ptr<cudnn_frontend::Tensor> raggedOffset) {
  // Every tensor built here is rank 4; dim and stride are read as 4-element arrays.
  constexpr int nbDims = 4;
  auto tensor_created =
      cudnn_frontend::TensorBuilder()
          .setDim(nbDims, dim)
          .setStride(nbDims, stride)
          .setId(id)
          .setAlignment(16)  // 16B alignment is needed to run a tensor core engine
          .setDataType(type)
          .setVirtual(is_virtual)
          .setByValue(is_value)
          .setRaggedOffset(raggedOffset)
          .build();
  return tensor_created;
}

291
292
293
// Build a cudnn_frontend pointwise descriptor.
//
//   type : data type passed to setComputeType()
//   mode : pointwise mode passed to setMode()
//
// Returns the descriptor produced by cudnn_frontend::PointWiseDescBuilder::build().
cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode) {
  auto pw_desc_created =
      cudnn_frontend::PointWiseDescBuilder().setMode(mode).setComputeType(type).build();
  return pw_desc_created;
}

297
298
299
300
301
302
303
304
305
// Build a pointwise operation node with a single x operand.
//
//   xDesc  : tensor descriptor bound with setxDesc()
//   yDesc  : tensor descriptor bound with setyDesc()
//   pwDesc : pointwise descriptor bound with setpwDesc()
//
// Returns the CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR operation produced by
// cudnn_frontend::OperationBuilder::build().
cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                             cudnn_frontend::Tensor const &yDesc,
                                             cudnn_frontend::PointWiseDesc const &pwDesc) {
  auto pw_op_created =
      cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(xDesc)
          .setyDesc(yDesc)
          .setpwDesc(pwDesc)
          .build();
  return pw_op_created;
}

309
310
311
312
313
314
315
316
317
318
319
// Build a pointwise operation node with x and b operands.
//
//   xDesc  : tensor descriptor bound with setxDesc()
//   bDesc  : tensor descriptor bound with setbDesc()
//   yDesc  : tensor descriptor bound with setyDesc()
//   pwDesc : pointwise descriptor bound with setpwDesc()
//
// Returns the CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR operation produced by
// cudnn_frontend::OperationBuilder::build().
cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                              cudnn_frontend::Tensor const &bDesc,
                                              cudnn_frontend::Tensor const &yDesc,
                                              cudnn_frontend::PointWiseDesc const &pwDesc) {
  auto pw_op_created =
      cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(xDesc)
          .setbDesc(bDesc)
          .setyDesc(yDesc)
          .setpwDesc(pwDesc)
          .build();
  return pw_op_created;
}

323
324
325
326
327
328
329
330
331
332
333
334
335
// Build a pointwise operation node with x, b and t operands.
//
//   xDesc  : tensor descriptor bound with setxDesc()
//   bDesc  : tensor descriptor bound with setbDesc()
//   tDesc  : tensor descriptor bound with settDesc()
//   yDesc  : tensor descriptor bound with setyDesc()
//   pwDesc : pointwise descriptor bound with setpwDesc()
//
// Returns the CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR operation produced by
// cudnn_frontend::OperationBuilder::build().
cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                               cudnn_frontend::Tensor const &bDesc,
                                               cudnn_frontend::Tensor const &tDesc,
                                               cudnn_frontend::Tensor const &yDesc,
                                               cudnn_frontend::PointWiseDesc const &pwDesc) {
  auto pw_op_created =
      cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(xDesc)
          .setbDesc(bDesc)
          .settDesc(tDesc)
          .setyDesc(yDesc)
          .setpwDesc(pwDesc)
          .build();
  return pw_op_created;
}

cyanguwa's avatar
cyanguwa committed
339
// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q
//
// Launch: 1-D indexing, one thread per index (blockIdx.x * blockDim.x + threadIdx.x).
// There is no stride loop, so the launch must provide at least b + 1 threads in total;
// surplus threads do nothing.
//
// Reads  cu_seqlens_q[0 .. b]      (b + 1 entries)
// Writes actual_seqlens_q[i]  = cu_seqlens_q[i + 1] - cu_seqlens_q[i]   for i in [0, b)
//        qkv_ragged_offset[i] = cu_seqlens_q[i] * 3 * h * d             for i in [0, b]
//        o_ragged_offset[i]   = cu_seqlens_q[i] * h * d                 for i in [0, b]
//
// The offset products are evaluated in size_t and narrowed to int32_t on store.
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q,
                                      int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
                                      int32_t *o_ragged_offset) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  // b per-sequence lengths ...
  if (tid < b) {
    actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid];
  }
  // ... but b + 1 offsets, one per cumulative entry.
  if (tid < b + 1) {
    qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d;
    o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d;
  }
}
352
353

// convert cu_seqlens to actual_seqlens
//
// Launch: 1-D indexing, one thread per sequence (blockIdx.x * blockDim.x + threadIdx.x).
// There is no stride loop, so the launch must provide at least b threads in total;
// surplus threads do nothing.
//
// Reads  q_cu_seqlens[0 .. b] and kv_cu_seqlens[0 .. b]   (b + 1 entries each)
// Writes q_seqlens[i]  = q_cu_seqlens[i + 1]  - q_cu_seqlens[i]    for i in [0, b)
//        kv_seqlens[i] = kv_cu_seqlens[i + 1] - kv_cu_seqlens[i]   for i in [0, b)
__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens,
                                             int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
                                             int32_t *kv_seqlens) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < b) {
    q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid];
    kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
  }
}
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394

// Turn padded cumulative sequence lengths into per-tensor element offsets.
//
// Launch: 1-D indexing, one thread per cumulative entry; at least b + 1 threads are
// needed because there is no stride loop. Threads past index b exit immediately.
//
// For every index i in [0, b]:
//   offsets_o[i] = h * d * cu_seqlens_q_padded[i]
//   offsets_q / offsets_k / offsets_v depend on layout_group:
//     NVTE_HD_HD_HD        : q = h * d * q_padded,  k = v = hg * d * kv_padded
//     NVTE_3HD / NVTE_H3D  : q = k = v = 3 * h * d * q_padded
//     NVTE_HD_2HD / HD_H2D : q = h * d * q_padded,  k = v = 2 * hg * d * kv_padded
//   Any other layout_group leaves offsets_q / offsets_k / offsets_v unwritten.
// Products are evaluated in size_t and narrowed to int32_t when stored.
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
                                             size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
                                             int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
                                             int32_t *offsets_k, int32_t *offsets_v,
                                             int32_t *offsets_o) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= b + 1) {
    return;
  }

  offsets_o[idx] = h * d * cu_seqlens_q_padded[idx];

  if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    // Q, K and V each live in their own buffer.
    offsets_q[idx] = h * d * cu_seqlens_q_padded[idx];
    const int32_t kv_offset = hg * d * cu_seqlens_kv_padded[idx];
    offsets_k[idx] = kv_offset;
    offsets_v[idx] = kv_offset;
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD ||
             layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
    // Q, K and V are packed together, so all three share one offset.
    const int32_t qkv_offset = 3 * h * d * cu_seqlens_q_padded[idx];
    offsets_q[idx] = qkv_offset;
    offsets_k[idx] = qkv_offset;
    offsets_v[idx] = qkv_offset;
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD ||
             layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
    // Q is separate; K and V are packed together.
    offsets_q[idx] = h * d * cu_seqlens_q_padded[idx];
    const int32_t kv_offset = 2 * hg * d * cu_seqlens_kv_padded[idx];
    offsets_k[idx] = kv_offset;
    offsets_v[idx] = kv_offset;
  }
}

cyanguwa's avatar
cyanguwa committed
395
396
397
398
399
400
}  // namespace fused_attn

// get cuDNN data type
//
// Maps a Transformer Engine DType onto the corresponding cudnnDataType_t enumerator.
// Handled: kInt32, kInt64, kFloat16, kFloat32, kBFloat16, kFloat8E4M3, kFloat8E5M2.
// Any other value is reported through NVTE_ERROR.
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kInt32:
      return CUDNN_DATA_INT32;
    case DType::kInt64:
      return CUDNN_DATA_INT64;
    case DType::kFloat16:
      return CUDNN_DATA_HALF;
    case DType::kFloat32:
      return CUDNN_DATA_FLOAT;
    case DType::kBFloat16:
      return CUDNN_DATA_BFLOAT16;
    case DType::kFloat8E4M3:
      return CUDNN_DATA_FP8_E4M3;
    case DType::kFloat8E5M2:
      return CUDNN_DATA_FP8_E5M2;
    default:
      NVTE_ERROR("Invalid cuDNN data type. \n");
  }
}
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441

// get cuDNN frontend data type
//
// Maps a Transformer Engine DType onto the corresponding cudnn_frontend::DataType_t
// enumerator. Handled: kInt32, kInt64, kFloat16, kFloat32, kBFloat16, kFloat8E4M3,
// kFloat8E5M2. Any other value is reported through NVTE_ERROR.
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
  using TeType = transformer_engine::DType;
  using FeType = cudnn_frontend::DataType_t;
  switch (t) {
    // floating-point types
    case TeType::kFloat32:
      return FeType::FLOAT;
    case TeType::kFloat16:
      return FeType::HALF;
    case TeType::kBFloat16:
      return FeType::BFLOAT16;
    case TeType::kFloat8E4M3:
      return FeType::FP8_E4M3;
    case TeType::kFloat8E5M2:
      return FeType::FP8_E5M2;
    // integer types
    case TeType::kInt32:
      return FeType::INT32;
    case TeType::kInt64:
      return FeType::INT64;
    default:
      NVTE_ERROR("Invalid cuDNN data type. \n");
  }
}
cyanguwa's avatar
cyanguwa committed
442
}  // namespace transformer_engine