/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"

#include <cuda_runtime.h>

#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
16

17
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
18
                    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
19
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
20
                    Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
21
22
                    JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
                    bool output_amax_when_no_scaling) {
23
24
25
  // parameters for clamped swiglu used in GPT OSS
  auto swiglu_limit = act_params.clamped_swiglu.limit;
  auto swiglu_alpha = act_params.clamped_swiglu.alpha;
26

27
28
29
30
31
32
33
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());

  auto *output = output_buf->untyped_data();
34
  auto *colwise_output = colwise_output_buf->untyped_data();
35
36
37
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
  NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
38
39

  auto input_dims = input_buf.dimensions();
40
  auto m = product(input_dims, 0, input_dims.size() - 2);
41
42
  auto n = input_dims.back();
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
43
  auto act_len = input_dims[input_dims.size() - 2];
44
  auto flatten_axis = output_buf->dimensions().size() - 1;  // output does not have act axis
45

46
47
48
  auto input_shape = std::vector<size_t>{m, static_cast<size_t>(act_len * n)};
  auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
  auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
49
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
50
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
51

52
  output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
53
54
55
56
  if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
      (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
  }
57

58
59
60
61
62
  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

63
  if (is_fp8_dtype(out_dtype)) {
64
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
65
66
67
68
69
70
71
72
73
74
75
76
77
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
    }
78
  }
79

80
  if (is_quantize_2x2x(quantize_layout)) {
81
82
83
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
84
85
86
87
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
88
89
90
91
      auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? scale_inv_buf
                          : colwise_scale_inv_buf;
      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
92
93
94
95
96
97
98
99
100
101
102
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{
                product(tmp_buf->dimensions(), 0, flatten_axis),
                product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
      }
    }
103
  }
104
105
106

  switch (act_type) {
    case NVTE_Activation_Type::GELU:
107
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
108
109
      break;
    case NVTE_Activation_Type::GEGLU:
110
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
111
      break;
Kim, Jin (Jay@SKT)'s avatar
Kim, Jin (Jay@SKT) committed
112
113
114
    case NVTE_Activation_Type::GLU:
      nvte_glu(input_tensor.data(), output_tensor.data(), stream);
      break;
115
    case NVTE_Activation_Type::SILU:
116
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
117
118
      break;
    case NVTE_Activation_Type::SWIGLU:
119
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
120
121
      break;
    case NVTE_Activation_Type::RELU:
122
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
123
124
      break;
    case NVTE_Activation_Type::REGLU:
125
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
126
127
      break;
    case NVTE_Activation_Type::QGELU:
128
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
129
130
      break;
    case NVTE_Activation_Type::QGEGLU:
131
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
132
133
      break;
    case NVTE_Activation_Type::SRELU:
134
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
135
136
      break;
    case NVTE_Activation_Type::SREGLU:
137
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
138
      break;
139
140
141
142
    case NVTE_Activation_Type::CLAMPED_SWIGLU:
      nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
                          stream);
      break;
143
144
145
146
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
147

148
149
150
  return ffi_with_cuda_error_check();
}

151
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
152
153
154
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
155
                                  .Arg<Buffer_Type>()      // scale
156
                                  .Arg<Buffer_Type>()      // amax
157
                                  .Ret<Buffer_Type>()      // output
158
159
160
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
161
                                  .Ret<Buffer_Type>()      // updated_amax
162
                                  .Attr<int64_t>("act_enum")
163
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
164
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
165
166
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"),
167
                              FFI_CudaGraph_Traits);
168

169
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
170
171
172
                              Buffer_Type amax_buf, Result_Type output_buf,
                              Result_Type colwise_output_buf, Result_Type scale_inv_buf,
                              Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
173
174
175
                              int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
                              JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
                              bool output_amax_when_no_scaling) {
176
177
  return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
                             output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
178
                             updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params,
179
                             output_amax_when_no_scaling);
180
181
182
183
184
185
186
}
// Initialize-phase handler registration; binding mirrors ActLuHandler but uses
// FFI::Bind<FFI_Initialize> and carries no CUDA-graph traits.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated_amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"));

199
200
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
201
202
                                                   JAXX_Scaling_Mode scaling_mode,
                                                   JAXX_Quantize_Layout quantize_layout) {
203
204
205
206
207
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};
208

209
210
211
212
213
  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

214
215
216
217
218
219
220
221
222
223
224
225
  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
  // on what pointers are allocated, e.g. whether to output with
  // column-wise data. However, we don't have access to any allocated
  // buffers in this function. We pass a dummy pointer as a
  // workaround.
  int temp = 0;

  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto dact_input_tensor =
      TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
226
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
227
228
229
230
231
232
  output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                        std::vector<size_t>{1});
  }
233

234
  if (is_quantize_2x2x(quantize_layout)) {
235
236
    auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
                                                                                : output_shape;
237
    output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
238
239
240
241
242
243
244

    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                             std::vector<size_t>{1});
    }
  }
245

246
  if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
247
248
249
250
251
252
253
    output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                           std::vector<size_t>{1});
    output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
                            std::vector<size_t>{1});
  }

  TensorWrapper dummy_workspace;
254
  // For now, all dbias_dact(-s) have the same workspace size
255
256
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);
257

258
259
  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
260
261
}

262
263
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
264
265
266
267
                                  Buffer_Type amax_buf, Result_Type output_buf,
                                  Result_Type colwise_output_buf, Result_Type scale_inv_buf,
                                  Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
                                  Result_Type dbias_buf, Result_Type workspace_buf,
268
269
270
                                  JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
                                  JAXX_Quantize_Layout quantize_layout, bool is_dbias,
                                  ActivationConfig act_params, bool output_amax_when_no_scaling) {
271
272
273
  // parameters for clamped swiglu used in GPT OSS
  auto swiglu_limit = act_params.clamped_swiglu.limit;
  auto swiglu_alpha = act_params.clamped_swiglu.alpha;
274

275
276
277
278
279
280
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
281
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
282
283
284
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
  NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
285

286
287
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto flatten_axis = output_buf->dimensions().size() - 2;  // output has act axis
288

289
290
291
292
293
  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

294
  auto *output = output_buf->untyped_data();
295
  auto *colwise_output = colwise_output_buf->untyped_data();
296
297
298
299
300
301
302
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
303
304
  // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
  auto act_len = act_input_dims[act_input_dims.size() - 2];
305
306
307
308
309
  NVTE_CHECK(act_len == 1 || act_len == 2,
             "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
             "gated activation, got ",
             act_len);

310
311
312
  auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  auto n = input_dims.back();

313
314
315
316
317
  auto input_shape = std::vector<size_t>{m, static_cast<size_t>(n)};
  auto act_input_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
  auto output_shape = std::vector<size_t>{m, static_cast<size_t>(n * act_len)};
  auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n * act_len), m};
  auto dbias_shape = std::vector<size_t>{static_cast<size_t>(n * act_len)};
318
319
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

320
321
322
323
  auto input_tensor =
      TensorWrapper(input, input_shape, convert_ffi_datatype_to_te_dtype(input_buf.element_type()));
  auto act_input_tensor = TensorWrapper(
      act_input, act_input_shape, convert_ffi_datatype_to_te_dtype(act_input_buf.element_type()));
324
325

  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
326
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
327
328
329
330
  if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
      (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
  }
331
  if (is_fp8_dtype(out_dtype)) {
332
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
333
334
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
335
336
337
338
339
340
341
342
343
344
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
345
    }
346
347
  }

348
  if (is_quantize_2x2x(quantize_layout)) {
349
350
351
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
352
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
353
354
355

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
356
357
358
359
      auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? scale_inv_buf
                          : colwise_scale_inv_buf;
      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
360
361
362
363
364
365
366
367
368
369
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{
                product(tmp_buf->dimensions(), 0, flatten_axis),
                product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
      }
370
    }
371
  }
372

373
374
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
375

376
  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
377
  NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
378
379
  NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING &&
               is_quantize_2x2x(quantize_layout) && act_len == 2),
380
             "TE/common does not support delayed scaling for 2x with gated activations.");
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432

  if (is_dbias) {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, "with dbias = True");
        break;
    }
  } else {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
Kim, Jin (Jay@SKT)'s avatar
Kim, Jin (Jay@SKT) committed
433
434
435
      case NVTE_Activation_Type::GLU:
        nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
436
437
438
439
440
441
442
443
444
445
446
447
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
448
449
450
451
      case NVTE_Activation_Type::CLAMPED_SWIGLU:
        nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                             swiglu_limit, swiglu_alpha, stream);
        break;
452
453
454
455
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
456
  }
457

458
459
460
  return ffi_with_cuda_error_check();
}

461
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
462
463
464
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
465
                                  .Arg<Buffer_Type>()      // act input
466
                                  .Arg<Buffer_Type>()      // scale
467
                                  .Arg<Buffer_Type>()      // amax
468
                                  .Ret<Buffer_Type>()      // output
469
470
471
472
473
474
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
475
476
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
477
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
478
                                  .Attr<bool>("is_dbias")
479
480
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"),
481
                              FFI_CudaGraph_Traits);
482

483
484
Error_Type DActLuDBiasQuantizeInitializeFFI(
    cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
485
486
487
    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
    Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
488
489
    int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias,
    ActivationConfig act_params, bool output_amax_when_no_scaling) {
490
  return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
491
492
                             act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
                             scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
493
494
                             workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias,
                             act_params, output_amax_when_no_scaling);
495
496
497
498
499
500
501
502
503
}
// Initialize-phase handler registration; binding mirrors DActLuDBiasQuantizeHandler
// but uses FFI::Bind<FFI_Initialize> and carries no CUDA-graph traits.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
                              DActLuDBiasQuantizeInitializeFFI,
                              FFI::Bind<FFI_Initialize>()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated_amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
                                  .Attr<JAXX_Quantize_Layout>("quantize_layout")
                                  .Attr<bool>("is_dbias")
                                  .Attr<ActivationConfig>("act_params")
                                  .Attr<bool>("output_amax_when_no_scaling"));

}  // namespace jax
}  // namespace transformer_engine