/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"

#include <cuda_runtime.h>

#include <functional>
#include <vector>

#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {
// Forward activation FFI entry point.
//
// The input is viewed as a 2D tensor of shape (m, act_len * n), where n is the last input
// dimension, act_len is the second-to-last one (the activation axis: 1 for non-gated, 2 for
// gated activations) and m is the product of all leading dimensions. The activation selected
// by `act_enum` writes a row-wise output of shape (m, n) and, for 2x quantize layouts, an
// additional column-wise output ((n, m) for delayed tensor scaling, (m, n) otherwise).
//
// `amax_buf` and `updated_amax_buf` must alias the same non-null buffer. Current tensor scaling
// is rejected because it does not support fused operations yet. Returns the result of
// ffi_with_cuda_error_check().
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                    Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
                    JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
                    bool output_amax_when_no_scaling) {
  // parameters for clamped swiglu used in GPT OSS
  auto swiglu_limit = act_params.clamped_swiglu.limit;
  auto swiglu_alpha = act_params.clamped_swiglu.alpha;

  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());

  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
  NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");

  auto input_dims = input_buf.dimensions();
  // `size() - 2` is unsigned arithmetic: reject low-rank inputs before it can wrap around.
  NVTE_CHECK(input_dims.size() >= 2,
             "Input must have at least 2 dimensions (activation axis and hidden axis), got ",
             input_dims.size());
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto act_len = input_dims[input_dims.size() - 2];
  // Same contract as the backward pass (DActLuDBiasQuantizeFFI).
  NVTE_CHECK(act_len == 1 || act_len == 2,
             "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
             "gated activation, got ",
             act_len);
  auto flatten_axis = output_buf->dimensions().size() - 1;  // output does not have act axis

  auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
  auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
  auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));

  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
      (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
  }

  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

  if (is_fp8_dtype(out_dtype)) {
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      // scale_inv is collapsed to 2D around flatten_axis.
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
    }
  }

  if (is_quantize_2x2x(quantize_layout)) {
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);

    if (is_fp8_dtype(out_dtype)) {
      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
        // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise
        // scaling
        output_tensor.set_columnwise_scale_inv(
            scale_inv_buf->untyped_data(),
            convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            colwise_scale_inv_buf->untyped_data(),
            convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
            std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, flatten_axis),
                                product(colwise_scale_inv_buf->dimensions(), flatten_axis,
                                        colwise_scale_inv_buf->dimensions().size())});
      }
    }
  }

  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::CLAMPED_SWIGLU:
      nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
                          stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }

  return ffi_with_cuda_error_check();
}

// Registers ActLuFFI as the execute-stage FFI handler. The binding order below must match the
// parameter order of ActLuFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated_amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"),
                              FFI_CudaGraph_Traits);

// Initialize-stage counterpart of ActLuFFI: forwards every argument unchanged to ActLuFFI
// through wrapInStreamCapture and returns its result.
// NOTE(review): presumably this runs ActLuFFI once under CUDA stream capture as a warm-up;
// confirm against the definition of wrapInStreamCapture.
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                              Buffer_Type amax_buf, Result_Type output_buf,
                              Result_Type colwise_output_buf, Result_Type scale_inv_buf,
                              Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
                              int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
                              JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
                              bool output_amax_when_no_scaling) {
  return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
                             output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
                             updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params,
                             output_amax_when_no_scaling);
}

// Registers ActLuInitializeFFI as the initialize-stage FFI handler. The binding order below
// must match the parameter order of ActLuInitializeFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated_amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"));

// Queries the workspace required by the fused dact + dbias + quantize kernels.
//
// Builds tensor wrappers of the requested shapes and dtypes around a dummy pointer, calls
// nvte_quantize_dbias_dgelu with an empty workspace tensor, and returns a one-element tuple
// holding the (shape, dtype) pair that the call reported on that workspace tensor.
// Current tensor scaling is rejected because it does not support fused operations yet.
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   JAXX_Scaling_Mode scaling_mode,
                                                   JAXX_Quantize_Layout quantize_layout) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
  // on what pointers are allocated, e.g. whether to output with
  // column-wise data. However, we don't have access to any allocated
  // buffers in this function. We pass a dummy pointer as a
  // workaround.
  int temp = 0;
  void *dummy_ptr = static_cast<void *>(&temp);

  auto input_tensor = TensorWrapper(dummy_ptr, input_shape, in_dtype);
  auto dact_input_tensor = TensorWrapper(dummy_ptr, dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(dummy_ptr, dbias_shape, in_dtype);
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
  output_tensor.set_rowwise_data(dummy_ptr, out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
  }

  if (is_quantize_2x2x(quantize_layout)) {
    auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
                                                                                : output_shape;
    output_tensor.set_columnwise_data(dummy_ptr, out_dtype, tmp_shape);

    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
    }
  }

  if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
    output_tensor.set_amax(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
    output_tensor.set_scale(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
  }

  TensorWrapper dummy_workspace;
  // For now, all dbias_dact(-s) have the same workspace size
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

// Backward activation FFI entry point, optionally fused with dbias and quantization.
//
// `act_input_buf` is the forward input x, viewed as (m, n * act_len) where act_len is its
// second-to-last dimension (1 for non-gated, 2 for gated activations) and m is the product of
// its leading dimensions. `input_buf` is the incoming gradient, viewed as (m, n) with n its
// last dimension. The output has shape (m, n * act_len); dbias has shape (n * act_len).
//
// `amax_buf` and `updated_amax_buf` must alias the same non-null buffer. Current tensor
// scaling, dbias fused with gated activations, and 2x delayed scaling with gated activations
// are rejected. Returns the result of ffi_with_cuda_error_check().
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                  Buffer_Type amax_buf, Result_Type output_buf,
                                  Result_Type colwise_output_buf, Result_Type scale_inv_buf,
                                  Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
                                  Result_Type dbias_buf, Result_Type workspace_buf,
                                  JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
                                  JAXX_Quantize_Layout quantize_layout, bool is_dbias,
                                  ActivationConfig act_params, bool output_amax_when_no_scaling) {
  // parameters for clamped swiglu used in GPT OSS
  auto swiglu_limit = act_params.clamped_swiglu.limit;
  auto swiglu_alpha = act_params.clamped_swiglu.alpha;

  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
  NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto flatten_axis = output_buf->dimensions().size() - 2;  // output has act axis

  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // `size() - 2` is unsigned arithmetic and `back()` needs a non-empty span: reject low-rank
  // buffers before indexing into them.
  NVTE_CHECK(act_input_dims.size() >= 2,
             "Activation input must have at least 2 dimensions (activation axis and hidden "
             "axis), got ",
             act_input_dims.size());
  NVTE_CHECK(input_dims.size() >= 1, "Input gradient must have at least 1 dimension.");
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
  // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
  auto act_len = act_input_dims[act_input_dims.size() - 2];
  NVTE_CHECK(act_len == 1 || act_len == 2,
             "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
             "gated activation, got ",
             act_len);

  auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  auto n = input_dims.back();

  auto input_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
  auto act_input_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
  auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
  auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n * act_len), m};
  auto dbias_shape = std::vector<size_t>{static_cast<size_t>(n * act_len)};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(
      act_input, act_input_shape, convert_ffi_datatype_to_te_dtype(act_input_buf.element_type()));

  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
      (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
  }
  if (is_fp8_dtype(out_dtype)) {
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      // scale_inv is collapsed to 2D around flatten_axis.
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
    }
  }

  if (is_quantize_2x2x(quantize_layout)) {
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);

    if (is_fp8_dtype(out_dtype)) {
      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
        // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise
        // scaling
        output_tensor.set_columnwise_scale_inv(
            scale_inv_buf->untyped_data(),
            convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            colwise_scale_inv_buf->untyped_data(),
            convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
            std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, flatten_axis),
                                product(colwise_scale_inv_buf->dimensions(), flatten_axis,
                                        colwise_scale_inv_buf->dimensions().size())});
      }
    }
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
  NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
  NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING &&
               is_quantize_2x2x(quantize_layout) && act_len == 2),
             "TE/common does not support delayed scaling for 2x with gated activations.");

  if (is_dbias) {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, " with dbias = True");
        break;
    }
  } else {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::CLAMPED_SWIGLU:
        nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                             swiglu_limit, swiglu_alpha, stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
  }

  return ffi_with_cuda_error_check();
}

// Registers DActLuDBiasQuantizeFFI as the execute-stage FFI handler. The binding order below
// must match the parameter order of DActLuDBiasQuantizeFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated_amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
                                  .Attr<bool>("is_dbias")
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"),
                              FFI_CudaGraph_Traits);

// Initialize-stage counterpart of DActLuDBiasQuantizeFFI: forwards every argument unchanged to
// DActLuDBiasQuantizeFFI through wrapInStreamCapture and returns its result.
// NOTE(review): presumably this runs the kernel path once under CUDA stream capture as a
// warm-up; confirm against the definition of wrapInStreamCapture.
Error_Type DActLuDBiasQuantizeInitializeFFI(
    cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
    Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
    int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
    ActivationConfig act_params, bool output_amax_when_no_scaling) {
  return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
                             act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
                             scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
                             workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias,
                             act_params, output_amax_when_no_scaling);
}

// Registers DActLuDBiasQuantizeInitializeFFI as the initialize-stage FFI handler. The binding
// order below must match the parameter order of DActLuDBiasQuantizeInitializeFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
                              DActLuDBiasQuantizeInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated_amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
                                  .Attr<bool>("is_dbias")
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"));

}  // namespace jax
}  // namespace transformer_engine