/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"

#include <cuda_runtime.h>

#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

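// XLA FFI handler for the fused activation forward pass: applies the
// activation selected by `act_enum` to the input, quantizing the result
// according to `scaling_mode`; when `is_2x` is set, a second column-wise
// copy of the output is produced as well.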
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                    Result_Type output_buf, Result_Type colwise_output_buf,
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                    Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum,
                    bool is_2x_int) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());

  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());

  auto input_dims = input_buf.dimensions();
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto act_len = input_dims[input_dims.size() - 2];
  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto is_2x = static_cast<bool>(is_2x_int);
  auto flatten_axis = output_buf->dimensions().size() - 1;  // output does not have act axis

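  // The input is viewed as (..., act_len, n) and flattened to (m, act_len * n);
  // plain activations (e.g. GELU) have act_len == 1, gated ones (e.g. SWIGLU)
  // have act_len == 2, so the output is (m, n) either way.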
  auto input_shape = std::vector<size_t>{m, act_len * n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(scaling_mode);
  output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);

  if (is_fp8_dtype(out_dtype)) {
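    // Delayed tensor scaling carries per-tensor scale/amax state; amax is
    // reset to zero before the kernel runs. All other scaling modes only
    // need a rowwise scale_inv view, split into 2D at flatten_axis.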
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
      cudaMemsetAsync(amax, 0, sizeof(float), stream);
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
    }
  }

  if (is_2x) {
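    // For delayed tensor scaling the column-wise copy is stored transposed as
    // (n, m); other scaling modes keep the row-major (m, n) layout.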
    auto &tmp_shape =
        (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &tmp_buf =
          (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
      if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{
                product(tmp_buf->dimensions(), 0, flatten_axis),
                product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
      }
    }
  }

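  // Dispatch to the TE kernel for the requested activation; each nvte_* call
  // reads the row-wise input and fills whichever output views were set above.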
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }

  return ffi_with_cuda_error_check();
}

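// Registers ActLuFFI with XLA; FFI_CudaGraph_Traits marks the handler as
// compatible with CUDA-graph capture (XLA command buffers).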
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<bool>("is_2x"),
                              FFI_CudaGraph_Traits);

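// Queries the workspace shape/dtype needed by the fused dbias + dact quantize
// kernels by calling nvte_quantize_dbias_dgelu with dummy pointers and an
// empty workspace tensor; TE fills in the required workspace shape instead of
// launching any work.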
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   int scaling_mode, bool is_2x) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
  // on what pointers are allocated, e.g. whether to output with
  // column-wise data. However, we don't have access to any allocated
  // buffers in this function. We pass a dummy pointer as a
  // workaround.
  int temp = 0;

  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto dact_input_tensor =
      TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
  auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode));
  output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                        std::vector<size_t>{1});
  }

  if (is_2x) {
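    // Registering a column-wise output (transposed for delayed tensor scaling)
    // makes TE size the workspace for the 2x code path.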
    auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
    output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);

    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                             std::vector<size_t>{1});
    }
  }

  if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
    output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                           std::vector<size_t>{1});
    output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
                            std::vector<size_t>{1});
  }

  TensorWrapper dummy_workspace;
  // For now, all dbias_dact(-s) have the same workspace size
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

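// XLA FFI handler for the activation backward pass: computes the activation
// gradient (dact), optionally fused with the bias gradient reduction
// (is_dbias) and quantization, mirroring the layout logic of ActLuFFI.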
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                  Result_Type output_buf, Result_Type colwise_output_buf,
                                  Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                                  Result_Type amax_buf, Result_Type dbias_buf,
                                  Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
                                  bool is_dbias, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());

  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto flatten_axis = output_buf->dimensions().size() - 2;  // output has act axis

  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
  // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
  auto act_len = act_input_dims[act_input_dims.size() - 2];
  NVTE_CHECK(act_input_dims.back() == input_dims.back(),
             "Shape mismatch between activation input and gradient input");
  auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
  auto n = input_dims.back();

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};
  auto output_trans_shape = std::vector<size_t>{n * act_len, m};
  auto dbias_shape = std::vector<size_t>{n * act_len};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(scaling_mode);
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (is_fp8_dtype(out_dtype)) {
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
      cudaMemsetAsync(amax, 0, sizeof(float), stream);
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                              product(scale_inv_buf->dimensions(), flatten_axis,
                                      scale_inv_buf->dimensions().size())});
    }
  }

  if (is_2x) {
    auto &tmp_shape =
        (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
    output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);

    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &tmp_buf =
          (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
      if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{1});
      } else {
        output_tensor.set_columnwise_scale_inv(
            tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
            std::vector<size_t>{
                product(tmp_buf->dimensions(), 0, flatten_axis),
                product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
      }
    }
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
  NVTE_CHECK(!(act_len == 2 && is_dbias), "Fused dbias is not supported for gated activations!");
  NVTE_CHECK(
      !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
      "TE/common does not support delayed scaling for 2x with gated activations.");

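  // With is_dbias, use the fused quantize + dbias + dact kernels, which TE
  // only provides for non-gated activations; otherwise fall back to the plain
  // dact kernels, which also cover the gated variants.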
  if (is_dbias) {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, " with dbias = True");
        break;
    }
  } else {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
  }

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<bool>("is_2x")
                                  .Attr<bool>("is_dbias")
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine