/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>
#include <vector>

#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {
16

17
18
19
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                    Result_Type output_buf, Result_Type colwise_output_buf,
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
20
                    Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
21
                    bool is_2x_int) {
22
23
24
25
26
27
28
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());

  auto *output = output_buf->untyped_data();
29
30
  auto *colwise_output = colwise_output_buf->untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
31
32

  auto input_dims = input_buf.dimensions();
33
  auto m = product(input_dims, 0, input_dims.size() - 2);
34
35
  auto n = input_dims.back();
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
36
37
  auto act_len = input_dims[input_dims.size() - 2];
  auto is_2x = static_cast<bool>(is_2x_int);
38
  auto flatten_axis = output_buf->dimensions().size() - 1;  // output does not have act axis
39

40
41
  auto input_shape = std::vector<size_t>{m, act_len * n};
  auto output_shape = std::vector<size_t>{m, n};
42
  auto output_trans_shape = std::vector<size_t>{n, m};
43
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
44
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
45
46
  output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);

47
48
49
50
51
  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

52
  if (is_fp8_dtype(out_dtype)) {
53
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
54
55
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
56
      nvte_memset(amax, 0, sizeof(float), stream);
57
58
59
60
61
62
63
64
65
66
67
68
69
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
    }
70
  }
71

72
  if (is_2x) {
73
74
75
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
76
77
78
79
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
80
81
82
83
      auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? scale_inv_buf
                          : colwise_scale_inv_buf;
      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
84
85
86
87
88
89
90
91
92
93
94
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{
                product(tmp_buf->dimensions(), 0, flatten_axis),
                product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
      }
    }
95
  }
96
97
98

  switch (act_type) {
    case NVTE_Activation_Type::GELU:
99
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
100
101
      break;
    case NVTE_Activation_Type::GEGLU:
102
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
103
104
      break;
    case NVTE_Activation_Type::SILU:
105
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
106
107
      break;
    case NVTE_Activation_Type::SWIGLU:
108
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
109
110
      break;
    case NVTE_Activation_Type::RELU:
111
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
112
113
      break;
    case NVTE_Activation_Type::REGLU:
114
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
115
116
      break;
    case NVTE_Activation_Type::QGELU:
117
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
118
119
      break;
    case NVTE_Activation_Type::QGEGLU:
120
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
121
122
      break;
    case NVTE_Activation_Type::SRELU:
123
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
124
125
      break;
    case NVTE_Activation_Type::SREGLU:
126
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
127
128
129
130
131
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
132

133
134
135
  return ffi_with_cuda_error_check();
}

136
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
137
138
139
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
140
                                  .Arg<Buffer_Type>()      // scale
141
                                  .Ret<Buffer_Type>()      // output
142
143
144
145
146
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")
147
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
148
                                  .Attr<bool>("is_2x"),
149
                              FFI_CudaGraph_Traits);
150

151
152
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
153
                                                   JAXX_Scaling_Mode scaling_mode, bool is_2x) {
154
155
156
157
158
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};
159

160
161
162
163
164
  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

165
166
167
168
169
170
171
172
173
174
175
176
  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
  // on what pointers are allocated, e.g. whether to output with
  // column-wise data. However, we don't have access to any allocated
  // buffers in this function. We pass a dummy pointer as a
  // workaround.
  int temp = 0;

  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto dact_input_tensor =
      TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
177
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
178
179
180
181
182
183
  output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                        std::vector<size_t>{1});
  }
184

185
  if (is_2x) {
186
187
    auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
                                                                                : output_shape;
188
    output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
189
190
191
192
193
194
195

    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                             std::vector<size_t>{1});
    }
  }
196

197
  if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
198
199
200
201
202
203
204
    output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                           std::vector<size_t>{1});
    output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
                            std::vector<size_t>{1});
  }

  TensorWrapper dummy_workspace;
205
  // For now, all dbias_dact(-s) have the same workspace size
206
207
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);
208

209
210
  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
211
212
}

213
214
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
215
216
217
                                  Result_Type output_buf, Result_Type colwise_output_buf,
                                  Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                                  Result_Type amax_buf, Result_Type dbias_buf,
218
219
                                  Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
                                  int64_t act_enum, bool is_2x, bool is_dbias) {
220
221
222
223
224
225
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
226
227
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
228

229
230
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto flatten_axis = output_buf->dimensions().size() - 2;  // output has act axis
231

232
233
234
235
236
  NVTE_CHECK(
      scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
      "Current tensor scaling does not support fused operations yet. Please call this primitive "
      "in higher-precision then quantize with current scaling.");

237
  auto *output = output_buf->untyped_data();
238
  auto *colwise_output = colwise_output_buf->untyped_data();
239
240
241
242
243
244
245
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
246
247
  // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
  auto act_len = act_input_dims[act_input_dims.size() - 2];
248
249
250
251
252
  NVTE_CHECK(act_len == 1 || act_len == 2,
             "The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
             "gated activation, got ",
             act_len);

253
254
255
256
257
258
259
260
  auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  auto n = input_dims.back();

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};
  auto output_trans_shape = std::vector<size_t>{n * act_len, m};
  auto dbias_shape = std::vector<size_t>{n * act_len};
261
262
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

263
264
265
266
  auto input_tensor =
      TensorWrapper(input, input_shape, convert_ffi_datatype_to_te_dtype(input_buf.element_type()));
  auto act_input_tensor = TensorWrapper(
      act_input, act_input_shape, convert_ffi_datatype_to_te_dtype(act_input_buf.element_type()));
267
268

  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
269
270
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (is_fp8_dtype(out_dtype)) {
271
    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
272
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
273
      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
274
      nvte_memset(amax, 0, sizeof(float), stream);
275
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
276
277
278
279
280
281
282
283
284
285
286
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
287
    }
288
289
  }

290
  if (is_2x) {
291
292
293
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
294
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
295
296
297

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
298
299
300
301
      auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? scale_inv_buf
                          : colwise_scale_inv_buf;
      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
302
303
304
305
306
307
308
309
310
311
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{
                product(tmp_buf->dimensions(), 0, flatten_axis),
                product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
      }
312
    }
313
  }
314

315
316
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
317

318
  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
319
  NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
320
321
  NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
             "TE/common does not support delayed scaling for 2x with gated activations.");
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389

  if (is_dbias) {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, "with dbias = True");
        break;
    }
  } else {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
390
  }
391

392
393
394
  return ffi_with_cuda_error_check();
}

395
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
396
397
398
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
399
                                  .Arg<Buffer_Type>()      // act input
400
401
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
402
403
404
405
406
407
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
408
409
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("act_enum")
410
                                  .Attr<bool>("is_2x")
411
                                  .Attr<bool>("is_dbias"),
412
                              FFI_CudaGraph_Traits);
413
414
}  // namespace jax
}  // namespace transformer_engine