/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"
7

8
9
#include <cuda_runtime.h>

10
#include "../extensions.h"
11
#include "transformer_engine/cast.h"
12
#include "xla/ffi/api/c_api.h"
13

14
15
namespace transformer_engine {
namespace jax {
16

17
18
19
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                    Result_Type output_buf, Result_Type colwise_output_buf,
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                    Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
                    bool is_2x_int, ActivationConfig act_params) {
  // Parameters for the clamped SwiGLU variant used in GPT OSS.
  const auto swiglu_limit = act_params.clamped_swiglu.limit;
  const auto swiglu_alpha = act_params.clamped_swiglu.alpha;

  const auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  const auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());

  const auto input_dims = input_buf.dimensions();
  const auto m = product(input_dims, 0, input_dims.size() - 2);
  const auto n = input_dims.back();
  const auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  // Axis -2 of the input holds the activation inputs (1 for plain, 2 for gated activations).
  const auto act_len = input_dims[input_dims.size() - 2];
  const bool is_2x = static_cast<bool>(is_2x_int);
  const auto flatten_axis = output_buf->dimensions().size() - 1;  // output does not have act axis

  const auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
  const auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
  const auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};

  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
  output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);

  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

  if (is_fp8_dtype(out_dtype)) {
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
      // Clear amax before the kernel accumulates into it.
      nvte_memset(amax, 0, sizeof(float), stream);
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      // Block-style scaling: scale_inv is a 2D grid flattened around flatten_axis.
      const auto &si_dims = scale_inv_buf->dimensions();
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(si_dims, 0, flatten_axis),
                              product(si_dims, flatten_axis, si_dims.size())});
    }
  }

  if (is_2x) {
    // Delayed tensor scaling produces a transposed columnwise copy; other modes keep the layout.
    const bool is_delayed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
    output_tensor.set_columnwise_data(colwise_output, out_dtype,
                                      is_delayed ? output_trans_shape : output_shape);

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &si_src = is_delayed ? scale_inv_buf : colwise_scale_inv_buf;
      if (is_delayed) {
        output_tensor.set_columnwise_scale_inv(
            si_src->untyped_data(), convert_ffi_datatype_to_te_dtype(si_src->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            si_src->untyped_data(), convert_ffi_datatype_to_te_dtype(si_src->element_type()),
            std::vector<size_t>{
                product(si_src->dimensions(), 0, flatten_axis),
                product(si_src->dimensions(), flatten_axis, si_src->dimensions().size())});
      }
    }
  }

  // Dispatch to the requested forward activation kernel.
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::CLAMPED_SWIGLU:
      nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
                          stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }

  return ffi_with_cuda_error_check();
}

143
// FFI registration for the fused activation forward pass; binding order must
// match the ActLuFFI parameter order.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<bool>("is_2x")
                                  .Attr<ActivationConfig>("act_params"),
                              FFI_CudaGraph_Traits);
158

159
160
161
162
// Initialization variant of ActLuFFI: forwards every argument through
// wrapInStreamCapture so the same implementation runs during the FFI
// initialize stage (presumably under CUDA stream capture — see wrapInStreamCapture).
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                              Result_Type output_buf, Result_Type colwise_output_buf,
                              Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                              Result_Type amax_buf, int64_t act_enum,
                              JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
                              ActivationConfig act_params) {
  return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
                             colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
                             act_enum, scaling_mode, is_2x_int, act_params);
}

// Initialize-stage registration; same binding layout as ActLuHandler but without
// the CUDA-graph traits.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<bool>("is_2x")
                                  .Attr<ActivationConfig>("act_params"));
184

185
186
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   JAXX_Scaling_Mode scaling_mode, bool is_2x) {
  const auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  const auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  const auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  const auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  const auto dbias_shape = std::vector<size_t>{hidden_size};

  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
  // on what pointers are allocated, e.g. whether to output with
  // column-wise data. However, we don't have access to any allocated
  // buffers in this function. We pass a dummy pointer as a
  // workaround.
  int temp = 0;
  auto *dummy_ptr = reinterpret_cast<void *>(&temp);

  auto input_tensor = TensorWrapper(dummy_ptr, input_shape, in_dtype);
  auto dact_input_tensor = TensorWrapper(dummy_ptr, dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(dummy_ptr, dbias_shape, in_dtype);

  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
  output_tensor.set_rowwise_data(dummy_ptr, out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
  }

  if (is_2x) {
    const bool is_delayed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
    output_tensor.set_columnwise_data(dummy_ptr, out_dtype,
                                      is_delayed ? output_trans_shape : output_shape);

    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
    }
  }

  if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
    output_tensor.set_amax(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
    output_tensor.set_scale(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
  }

  // Querying with an empty workspace tensor fills in the required workspace shape.
  TensorWrapper dummy_workspace;
  // For now, all dbias_dact(-s) have the same workspace size
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);

  const auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

247
248
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                  Result_Type output_buf, Result_Type colwise_output_buf,
                                  Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                                  Result_Type amax_buf, Result_Type dbias_buf,
                                  Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
                                  int64_t act_enum, bool is_2x, bool is_dbias,
                                  ActivationConfig act_params) {
  // Parameters for the clamped SwiGLU variant used in GPT OSS.
  const auto swiglu_limit = act_params.clamped_swiglu.limit;
  const auto swiglu_alpha = act_params.clamped_swiglu.alpha;

  const auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  const auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  const auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());

  const auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  const auto flatten_axis = output_buf->dimensions().size() - 2;  // output has act axis

  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  auto *workspace = workspace_buf->untyped_data();

  const auto input_dims = input_buf.dimensions();
  const auto act_input_dims = act_input_buf.dimensions();
  const auto workspace_dims = workspace_buf->dimensions();

  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
  // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
  const auto act_len = act_input_dims[act_input_dims.size() - 2];
  NVTE_CHECK(act_len == 1 || act_len == 2,
             "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
             "gated activation, got ",
             act_len);

  const auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  const auto n = input_dims.back();

  const auto input_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
  const auto act_input_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
  const auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
  const auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n * act_len), m};
  const auto dbias_shape = std::vector<size_t>{static_cast<size_t>(n * act_len)};
  const std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(
      act_input, act_input_shape, convert_ffi_datatype_to_te_dtype(act_input_buf.element_type()));

  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (is_fp8_dtype(out_dtype)) {
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
      // Clear amax before the kernel accumulates into it.
      nvte_memset(amax, 0, sizeof(float), stream);
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      // Block-style scaling: scale_inv is a 2D grid flattened around flatten_axis.
      const auto &si_dims = scale_inv_buf->dimensions();
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(si_dims, 0, flatten_axis),
                              product(si_dims, flatten_axis, si_dims.size())});
    }
  }

  if (is_2x) {
    // Delayed tensor scaling produces a transposed columnwise copy; other modes keep the layout.
    const bool is_delayed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
    output_tensor.set_columnwise_data(colwise_output, out_dtype,
                                      is_delayed ? output_trans_shape : output_shape);

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &si_src = is_delayed ? scale_inv_buf : colwise_scale_inv_buf;
      if (is_delayed) {
        output_tensor.set_columnwise_scale_inv(
            si_src->untyped_data(), convert_ffi_datatype_to_te_dtype(si_src->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            si_src->untyped_data(), convert_ffi_datatype_to_te_dtype(si_src->element_type()),
            std::vector<size_t>{
                product(si_src->dimensions(), 0, flatten_axis),
                product(si_src->dimensions(), flatten_axis, si_src->dimensions().size())});
      }
    }
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
  NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
  NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
             "TE/common does not support delayed scaling for 2x with gated activations.");

  if (is_dbias) {
    // Fused dact + dbias + quantize path (non-gated activations only).
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, "with dbias = True");
        break;
    }
  } else {
    // Plain dact backward (optionally quantizing through output_tensor's scaling setup).
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::CLAMPED_SWIGLU:
        nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                             swiglu_limit, swiglu_alpha, stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
  }

  return ffi_with_cuda_error_check();
}

437
// FFI registration for the fused dact/dbias/quantize backward pass; binding
// order must match the DActLuDBiasQuantizeFFI parameter order.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
                                  .Attr<bool>("is_2x")
                                  .Attr<bool>("is_dbias")
                                  .Attr<ActivationConfig>("act_params"),
                              FFI_CudaGraph_Traits);
456

457
458
459
460
461
462
// Initialization variant of DActLuDBiasQuantizeFFI: forwards every argument
// through wrapInStreamCapture so the same implementation runs during the FFI
// initialize stage (presumably under CUDA stream capture — see wrapInStreamCapture).
Error_Type DActLuDBiasQuantizeInitializeFFI(
    cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
    Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
    Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
    Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
    bool is_dbias, ActivationConfig act_params) {
  return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
                             act_input_buf, scale_buf, output_buf, colwise_output_buf,
                             scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
                             workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
}

// Initialize-stage registration; same binding layout as DActLuDBiasQuantizeHandler
// but without the CUDA-graph traits.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
                              DActLuDBiasQuantizeInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
                                  .Attr<bool>("is_2x")
                                  .Attr<bool>("is_dbias")
                                  .Attr<ActivationConfig>("act_params"));
488

489
490
}  // namespace jax
}  // namespace transformer_engine