activation.cpp 30.3 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"
7

8
#include "extensions.h"
9
#include "transformer_engine/cast.h"
10
#include "transformer_engine/transpose.h"
11
#include "xla/ffi/api/c_api.h"
12
13
14
15

namespace transformer_engine {
namespace jax {

16
// TODO: We won't need this function anymore when we move to the new XLA custom calls
17
18
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
  switch (activation_enum) {
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    case NVTE_Activation_Type::GELU:
      return 1;
    case NVTE_Activation_Type::GEGLU:
      return 2;
    case NVTE_Activation_Type::SILU:
      return 1;
    case NVTE_Activation_Type::SWIGLU:
      return 2;
    case NVTE_Activation_Type::RELU:
      return 1;
    case NVTE_Activation_Type::REGLU:
      return 2;
    case NVTE_Activation_Type::QGELU:
      return 1;
    case NVTE_Activation_Type::QGEGLU:
      return 2;
    case NVTE_Activation_Type::SRELU:
      return 1;
    case NVTE_Activation_Type::SREGLU:
      return 2;
39
40
41
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
42
      return -1;
43
44
45
46
  }
}

void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
47
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
48
               NVTE_Activation_Type act_enum, size_t act_len) {
49
50
51
52
53
54
  auto input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                     scale, scale_inverse);
  switch (act_enum) {
55
    case NVTE_Activation_Type::GELU:
56
57
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
58
    case NVTE_Activation_Type::GEGLU:
59
60
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
61
    case NVTE_Activation_Type::SILU:
62
63
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
64
    case NVTE_Activation_Type::SWIGLU:
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
89
90
91
}

// Legacy custom-call entry point for the forward activation without scaling metadata.
// buffers: [0] activation input, [1] output. The opaque blob is a CustomCallCommonDescriptor
// whose shape.dims[0] / shape.dims[1] give (m, n) and which also carries the dtypes and the
// activation type; act_len is derived from the activation type.
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &descriptor = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto activation = static_cast<NVTE_Activation_Type>(descriptor.act_enum);

  ActLuImpl(buffers[0], descriptor.shape.dims[0], descriptor.shape.dims[1], descriptor.in_dtype,
            descriptor.out_dtype, /*scale=*/nullptr, stream, /*scale_inverse=*/nullptr,
            /*amax=*/nullptr, buffers[1], activation, get_activation_len(activation));
}

105
106
107
108
109
110
111
112
113
// XLA FFI handler for the forward activation without scaling metadata.
// The input buffer is interpreted with layout (..., act_len, n): the axes before the last two
// are folded into m, the second-to-last axis gives act_len and the last axis gives n.
// The output dtype and pointer are taken from output_buf; no scale/scale_inv/amax is attached.
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
                    int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  // The shape arithmetic below addresses the last two axes: with rank < 2, `size() - 2` would
  // underflow and `end()[-2]` would read before the start of the dimension list.
  NVTE_CHECK(input_dims.size() >= 2,
             "ActLuFFI expects an input of rank >= 2 with layout (..., act_len, n).");
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_len = input_dims.end()[-2];
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type,
            act_len);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
132

133
// Legacy custom-call entry point for the forward activation with optional FP8 output.
// buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output, [5] amax_out.
// amax and amax_out must be the same buffer. When the descriptor's output dtype is not FP8,
// scale, scale_inv and amax_out are dropped (passed as nullptr).
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto as_float = [](void *ptr) { return reinterpret_cast<float *>(ptr); };

  float *const amax_in = as_float(buffers[1]);
  float *const amax_out_buf = as_float(buffers[5]);
  NVTE_CHECK(amax_in == amax_out_buf, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const bool fp8_output = use_fp8(desc.out_dtype);
  float *scale = fp8_output ? as_float(buffers[2]) : nullptr;
  float *scale_inv = fp8_output ? as_float(buffers[3]) : nullptr;
  float *amax_out = fp8_output ? amax_out_buf : nullptr;

  const auto activation = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ActLuImpl(buffers[0], desc.shape.dims[0], desc.shape.dims[1], desc.in_dtype, desc.out_dtype,
            scale, stream, scale_inv, amax_out, buffers[4], activation,
            get_activation_len(activation));
}

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
// XLA FFI handler for the forward activation with optional FP8 output.
// amax must alias amax_out (checked below). When the output dtype is not FP8, scale,
// scale_inv and amax are dropped and the output carries no scaling metadata.
// The input buffer is interpreted with layout (..., act_len, n): the axes before the last two
// are folded into m, the second-to-last axis gives act_len and the last axis gives n.
Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
                       Result_Type amax_out_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());

  auto *output = output_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");

  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  // The shape arithmetic below addresses the last two axes: with rank < 2, `size() - 2` would
  // underflow and `end()[-2]` would read before the start of the dimension list.
  NVTE_CHECK(input_dims.size() >= 2,
             "ActLuFP8FFI expects an input of rank >= 2 with layout (..., act_len, n).");
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_len = input_dims.end()[-2];
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  ActLuImpl(input, m, n, in_dtype, out_dtype, scale, stream, scale_inv, amax_out, output, act_type,
            act_len);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

202
// Legacy custom-call entry point for the activation backward.
// buffers: [0] input (m, n), [1] act_input (m, n * act_len), [2] output (m, n * act_len).
// (m, n), the dtypes and the activation type come from the CustomCallCommonDescriptor in
// `opaque`; act_len is derived from the activation type. input and act_input use in_dtype,
// output uses out_dtype. Unsupported activation types are reported through NVTE_ERROR.
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto activation = static_cast<NVTE_Activation_Type>(desc.act_enum);

  const size_t rows = desc.shape.dims[0];
  const size_t cols = desc.shape.dims[1];
  const size_t act_cols = cols * get_activation_len(activation);

  TensorWrapper in_t(buffers[0], std::vector<size_t>{rows, cols}, desc.in_dtype);
  TensorWrapper act_in_t(buffers[1], std::vector<size_t>{rows, act_cols}, desc.in_dtype);
  TensorWrapper out_t(buffers[2], std::vector<size_t>{rows, act_cols}, desc.out_dtype);

  // Every backward activation entry point shares one signature; select, then launch once.
  using DActFn = decltype(&nvte_dgelu);
  DActFn dact_fn = nullptr;
  switch (activation) {
    case NVTE_Activation_Type::GELU:
      dact_fn = &nvte_dgelu;
      break;
    case NVTE_Activation_Type::GEGLU:
      dact_fn = &nvte_dgeglu;
      break;
    case NVTE_Activation_Type::SILU:
      dact_fn = &nvte_dsilu;
      break;
    case NVTE_Activation_Type::SWIGLU:
      dact_fn = &nvte_dswiglu;
      break;
    case NVTE_Activation_Type::RELU:
      dact_fn = &nvte_drelu;
      break;
    case NVTE_Activation_Type::REGLU:
      dact_fn = &nvte_dreglu;
      break;
    case NVTE_Activation_Type::QGELU:
      dact_fn = &nvte_dqgelu;
      break;
    case NVTE_Activation_Type::QGEGLU:
      dact_fn = &nvte_dqgeglu;
      break;
    case NVTE_Activation_Type::SRELU:
      dact_fn = &nvte_dsrelu;
      break;
    case NVTE_Activation_Type::SREGLU:
      dact_fn = &nvte_dsreglu;
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  if (dact_fn != nullptr) {
    dact_fn(in_t.data(), act_in_t.data(), out_t.data(), stream);
  }
}

258
259
260
261
262
263
264
265
266
267
// XLA FFI handler for the activation backward.
// act_input is interpreted with layout (..., act_len, n): the axes before the last two are
// folded into m, the second-to-last axis gives act_len and the last axis gives n. input is
// wrapped as (m, n); act_input and output as (m, n * act_len). input and act_input use the
// input buffer's dtype, output uses its own. Unsupported activation types raise via NVTE_ERROR.
Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf,
                     Result_Type output_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto act_input_dims = act_input_buf.dimensions();
  // The shape arithmetic below addresses the last two axes: with rank < 2, `size() - 2` would
  // underflow and `end()[-2]` would read before the start of the dimension list.
  NVTE_CHECK(act_input_dims.size() >= 2,
             "DActLuFFI expects an act_input of rank >= 2 with layout (..., act_len, n).");
  auto m = static_cast<size_t>(product(act_input_dims, 0, act_input_dims.size() - 2));
  auto n = static_cast<size_t>(act_input_dims.back());
  auto act_len = act_input_dims.end()[-2];

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};

  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype));

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
327

328
// Queries the workspace shape and dtype needed by the fused dact + dbias + cast-transpose path
// for a (batch_size, hidden_size) problem, returned as a one-element Python tuple holding
// (workspace shape, workspace dtype).
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
  // Evil hack to specify TE impl: nvte_quantize_dbias_dgelu chooses its internal implementation
  // from which pointers are set (e.g. whether column-wise output is present), but no real buffers
  // exist in this function. Every tensor therefore points at the same dummy stack integer.
  int placeholder = 0;
  void *const fake_ptr = static_cast<void *>(&placeholder);

  const std::vector<size_t> shape_2d{batch_size, hidden_size};
  const std::vector<size_t> shape_2d_transposed{hidden_size, batch_size};

  TensorWrapper input_t(fake_ptr, shape_2d, in_dtype);
  TensorWrapper dact_input_t(fake_ptr, shape_2d, in_dtype);
  TensorWrapper output_t;
  output_t.set_rowwise_data(fake_ptr, out_dtype, shape_2d);
  output_t.set_columnwise_data(fake_ptr, out_dtype, shape_2d_transposed);
  TensorWrapper dbias_t(fake_ptr, std::vector<size_t>{hidden_size}, in_dtype);
  TensorWrapper workspace_t;

  // All dbias_dact variants are assumed to need the same workspace, so the GELU variant stands
  // in for every activation. A null stream is passed; the call presumably only fills in the
  // workspace shape/dtype because workspace_t has no data pointer (confirm against TE docs).
  nvte_quantize_dbias_dgelu(input_t.data(), dact_input_t.data(), output_t.data(), dbias_t.data(),
                            workspace_t.data(), nullptr);

  return pybind11::make_tuple(
      std::make_pair(MakeShapeVector(workspace_t.shape()), workspace_t.dtype()));
}

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
363
364
365
366
367
368
369
370
371
372
373
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  auto *dbias = buffers[7];
  float *amax_out = reinterpret_cast<float *>(buffers[8]);
  void *workspace_ptr = buffers[9];
374

375
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
376
377
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
378
379
380
381
382
383
384
385
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
386

387
388
389
390
391
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
392

393
394
395
396
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
397
398
  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
399
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
400

401
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
402

403
404
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
405
406
      nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                dbias_tensor.data(), workspace.data(), stream);
407
408
      break;
    case NVTE_Activation_Type::SILU:
409
410
      nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                dbias_tensor.data(), workspace.data(), stream);
411
412
      break;
    case NVTE_Activation_Type::RELU:
413
414
      nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                dbias_tensor.data(), workspace.data(), stream);
415
416
      break;
    case NVTE_Activation_Type::QGELU:
417
418
      nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 dbias_tensor.data(), workspace.data(), stream);
419
420
      break;
    case NVTE_Activation_Type::SRELU:
421
422
      nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 dbias_tensor.data(), workspace.data(), stream);
423
424
425
426
427
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
428
429
}

430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
// XLA FFI handler fusing the activation backward, a dbias reduction and a cast + transpose of
// the result.
// m is the product of all act_input axes except the last two; n is the last axis of input.
// input, act_input and output are wrapped as (m, n), output_trans as (n, m) and dbias as (n).
// amax must alias amax_out; for non-FP8 outputs scale, scale_inv and amax are dropped.
// The workspace is passed through with its own shape and dtype.
// Only GELU, SILU, RELU, QGELU and SRELU are handled; others raise via NVTE_ERROR.
Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                       Buffer_Type act_input_buf, Buffer_Type amax_buf,
                                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                       Result_Type output_buf, Result_Type output_trans_buf,
                                       Result_Type dbias_buf, Result_Type amax_out_buf,
                                       Result_Type workspace_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  void *workspace = workspace_buf->untyped_data();
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // `input_ranks - 1` and `act_input_dims.size() - 2` below must not underflow.
  NVTE_CHECK(input_dims.size() >= 1, "DActLuDBiasCastTransposeFFI expects an input of rank >= 1.");
  NVTE_CHECK(act_input_dims.size() >= 2,
             "DActLuDBiasCastTransposeFFI expects an act_input of rank >= 2.");
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
  // n = ir_dz_shape[-1], ir_dz_shape == input_dims
  auto input_ranks = input_dims.size();
  auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  auto n = product(input_dims, input_ranks - 1, input_ranks);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 dbias_tensor.data(), workspace_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler, DActLuDBiasCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

526
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
527
528
529
530
531
532
533
534
535
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
536

537
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
538
539
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
540
541
542
543
544
545
546
547
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
548

549
550
551
552
  auto input_shape = desc.shape.to_vector();
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
553

554
555
556
557
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
558
559
  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
560

561
562
563
  switch (act_enum) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
564
                                 stream);
565
566
567
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
568
                                  output_tensor.data(), stream);
569
570
571
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
572
                                 stream);
573
574
575
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
576
                                  output_tensor.data(), stream);
577
578
579
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
580
                                  output_tensor.data(), stream);
581
582
583
584
585
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
586
587
}

588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
// XLA FFI handler fusing the gated-activation backward with a cast + transpose.
// act_input is interpreted as (..., 2, n): the axes before the last two are folded into m and
// the last axis gives n. input is wrapped as (m, n), act_input and output as (m, 2 * n) and
// output_trans as (2 * n, m). amax must alias amax_out; for non-FP8 outputs scale, scale_inv
// and amax are dropped. Only GEGLU, SWIGLU, REGLU, QGEGLU and SREGLU are handled.
Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                       Buffer_Type act_input_buf, Buffer_Type amax_buf,
                                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                       Result_Type output_buf, Result_Type output_trans_buf,
                                       Result_Type amax_out_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto act_input_dims = act_input_buf.dimensions();
  auto act_input_ranks = act_input_dims.size();
  // `act_input_ranks - 2` below must not underflow.
  NVTE_CHECK(act_input_ranks >= 2,
             "DGatedActLuCastTransposeFFI expects an act_input of rank >= 2.");
  auto m = product(act_input_dims, 0, act_input_ranks - 2);
  auto n = product(act_input_dims, act_input_ranks - 1, act_input_ranks);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler, DGatedActLuCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

671
672
}  // namespace jax
}  // namespace transformer_engine