/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"
7

8
9
#include <cuda_runtime.h>

10
#include "extensions.h"
11
#include "transformer_engine/cast.h"
12
#include "xla/ffi/api/c_api.h"
13

14
15
16
17
18
namespace {
// Returns true when `act_type` is one of the gated activation variants
// (GEGLU, SWIGLU, REGLU, QGEGLU, SREGLU) and false for every other value.
bool is_gated(NVTE_Activation_Type act_type) {
  switch (act_type) {
    case NVTE_Activation_Type::GEGLU:
    case NVTE_Activation_Type::SWIGLU:
    case NVTE_Activation_Type::REGLU:
    case NVTE_Activation_Type::QGEGLU:
    case NVTE_Activation_Type::SREGLU:
      return true;
    default:
      return false;
  }
}
}  // namespace
21

22
23
namespace transformer_engine {
namespace jax {
24

25
26
27
28
29
// XLA FFI entry point for the forward activation, optionally producing quantized output.
//
// The input buffer is viewed as a 2D tensor {m, act_len * n}, where `n` is the last input
// dimension, `act_len` is the second-to-last one and `m` is the product of all leading
// dimensions. The output is viewed as {m, n}.
//
// Parameters:
//   stream                 CUDA stream on which all work is enqueued.
//   input_buf              Activation input; must have at least 2 dimensions.
//   scale_buf              Scale; attached only for delayed tensor scaling with an FP8 output.
//   output_buf             Row-wise output.
//   colwise_output_buf     Column-wise output; attached only when `is_2x_int` is true.
//   scale_inv_buf          Row-wise scale_inv; attached only when the output dtype is FP8.
//   colwise_scale_inv_buf  Column-wise scale_inv; attached only when `is_2x_int` is true.
//   amax_buf               amax; zeroed and attached only for delayed tensor scaling with FP8.
//   act_enum               Integer value of NVTE_Activation_Type selecting the kernel.
//   scaling_mode_enum      Integer value of NVTEScalingMode.
//   is_2x_int              Whether column-wise data/scale_inv are attached to the output.
//
// Returns the result of ffi_with_cuda_error_check(). Invalid arguments are reported through
// NVTE_CHECK / NVTE_ERROR.
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                    Result_Type output_buf, Result_Type colwise_output_buf,
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                    Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum,
                    bool is_2x_int) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());

  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());

  auto input_dims = input_buf.dimensions();
  // size() is unsigned: without this guard `size() - 2` would wrap around for rank < 2.
  NVTE_CHECK(input_dims.size() >= 2,
             "ActLu input must have at least 2 dimensions, but got ", input_dims.size());
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = static_cast<size_t>(input_dims.back());
  auto act_len = static_cast<size_t>(input_dims[input_dims.size() - 2]);

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  const bool is_2x = is_2x_int;
  const bool is_fp8_output = is_fp8_dtype(out_dtype);

  // Collapses a scale_inv buffer's dimensions to 2D: {product of leading dims, last dim}.
  auto scale_inv_shape = [](auto dims) {
    return std::vector<size_t>{product(dims, 0, dims.size() - 1),
                               static_cast<size_t>(dims.back())};
  };

  auto input_shape = std::vector<size_t>{m, act_len * n};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(scaling_mode);
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);

  if (is_fp8_output) {
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
        scale_inv_shape(scale_inv_buf->dimensions()));
  }

  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_output) {
    NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
    NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
    // Reset amax before the kernel runs; fail fast if the memset cannot be enqueued.
    NVTE_CHECK(cudaMemsetAsync(amax, 0, sizeof(float), stream) == cudaSuccess,
               "Failed to enqueue cudaMemsetAsync for amax");
    output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
  }

  if (is_2x) {
    output_tensor.set_columnwise_data(colwise_output, out_dtype, output_shape);
    // NOTE(review): unlike the row-wise scale_inv, this is attached regardless of whether the
    // output dtype is FP8 -- confirm this is intended.
    output_tensor.set_columnwise_scale_inv(
        colwise_scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
        scale_inv_shape(colwise_scale_inv_buf->dimensions()));
  }

  // Dispatch to the kernel matching the requested activation.
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }

  return ffi_with_cuda_error_check();
}

120
// Registers ActLuFFI as the XLA FFI handler symbol `ActLuHandler`.
// The binding order below must match the parameter order of ActLuFFI:
// context (stream), then input buffers, then result buffers, then attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")      // -> act_enum
                                  .Attr<int64_t>("scaling_mode")  // -> scaling_mode_enum
                                  .Attr<bool>("is_2x"),           // -> is_2x_int
                              FFI_CudaGraph_Traits);
134

135
136
137
// Queries the workspace shape and dtype reported by nvte_quantize_dbias_dgelu for the given
// problem configuration.
//
// Parameters:
//   batch_size, hidden_size  Dimensions of the 2D {batch_size, hidden_size} problem.
//   in_dtype                 Dtype used for the input, activation-input and dbias tensors.
//   out_dtype                Dtype of the (possibly quantized) output tensor.
//   scaling_mode             Integer value of NVTEScalingMode.
//   is_2x                    Whether column-wise output is attached.
//
// Returns a one-element tuple holding the pair (workspace shape, workspace dtype).
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   int scaling_mode, bool is_2x) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Workaround to steer the TE implementation choice.
  // nvte_quantize_dbias_dgelu chooses its internal impl based on which pointers are allocated,
  // e.g. whether to output column-wise data. No real buffers are available in this function,
  // so every tensor is given the address of a local dummy value instead.
  int temp = 0;
  void *dummy_ptr = reinterpret_cast<void *>(&temp);

  const bool is_fp8_output = is_fp8_dtype(out_dtype);
  const auto te_scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);

  auto input_tensor = TensorWrapper(dummy_ptr, input_shape, in_dtype);
  auto dact_input_tensor = TensorWrapper(dummy_ptr, dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(dummy_ptr, dbias_shape, in_dtype);

  auto output_tensor = TensorWrapper(te_scaling_mode);
  output_tensor.set_rowwise_data(dummy_ptr, out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_output) {
    output_tensor.set_rowwise_scale_inv(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
  }

  if (is_2x) {
    output_tensor.set_columnwise_data(dummy_ptr, out_dtype, output_trans_shape);

    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_output) {
      output_tensor.set_columnwise_scale_inv(dummy_ptr, DType::kFloat32,
                                             std::vector<size_t>{1});
    }
  }

  if (is_fp8_output && te_scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
    output_tensor.set_amax(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
    output_tensor.set_scale(dummy_ptr, DType::kFloat32, std::vector<size_t>{1});
  }

  // The workspace wrapper is left empty; its shape and dtype are read back after the call.
  TensorWrapper dummy_workspace;
  // For now, all dbias_dact(-s) have the same workspace size
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

191
192
193
194
195
196
197
// XLA FFI entry point for the activation backward pass, optionally fused with the bias
// gradient and with quantization of the output.
//
// With m = product of all but the last dimension of `act_input_buf` and n = its last dimension,
// the incoming gradient is viewed as {m, input_dims.back()} and the activation input, row-wise
// output and column-wise output as {m, n}; dbias is viewed as {n}.
//
// Parameters:
//   stream               CUDA stream on which all work is enqueued.
//   input_buf            Incoming gradient.
//   act_input_buf        Input that was fed to the forward activation.
//   scale_buf            Scale; read only for delayed tensor scaling with an FP8 output.
//   output_buf           Row-wise output.
//   output_trans_buf     Column-wise output; attached only when `is_2x` is true.
//   scale_inv_buf        Row-wise scale_inv; attached only when the output dtype is FP8.
//   trans_scale_inv_buf  Column-wise scale_inv; used for 2x FP8 output unless the scaling mode
//                        is delayed tensor scaling, where `scale_inv_buf` is shared instead.
//   amax_out_buf         amax; zeroed and attached only for delayed tensor scaling with FP8.
//   dbias_buf            Bias gradient output; passed to the kernel only when `is_dbias`.
//   workspace_buf        Workspace; passed to the kernel only when `is_dbias`.
//   scaling_mode_enum    Integer value of NVTEScalingMode.
//   is_2x                Whether column-wise output is attached.
//   is_dbias             Whether the fused dbias kernels are used.
//   act_enum             Integer value of NVTE_Activation_Type selecting the kernel.
//
// Returns the result of ffi_with_cuda_error_check(). Unsupported combinations and invalid
// arguments are reported through NVTE_CHECK / NVTE_ERROR.
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                  Result_Type output_buf, Result_Type output_trans_buf,
                                  Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
                                  Result_Type amax_out_buf, Result_Type dbias_buf,
                                  Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
                                  bool is_dbias, int64_t act_enum) {
  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  // Reject unsupported combinations up front, before any work is enqueued on the stream.
  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
  NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!");
  NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x &&
               is_gated(act_type)),
             "TE/common does not support delayed scaling for 2x with gated activations.");

  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
  const bool is_fp8_output = is_fp8_dtype(out_dtype);

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  // size() is unsigned: guard against `size() - 1` wrapping and back() on empty dimensions.
  NVTE_CHECK(input_dims.size() >= 1 && act_input_dims.size() >= 1,
             "DActLuDBiasQuantize inputs must have at least 1 dimension");

  // m = x_batch_size = reduce(operator.mul, x_shape[:-1]), x_shape == act_input_dims
  // n = x_shape[-1]
  auto m = product(act_input_dims, 0, act_input_dims.size() - 1);
  // 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
  auto n = static_cast<size_t>(act_input_dims.back());

  // Collapses a scale_inv buffer's dimensions to 2D: {product of leading dims, last dim}.
  auto scale_inv_shape = [](auto dims) {
    return std::vector<size_t>{product(dims, 0, dims.size() - 1),
                               static_cast<size_t>(dims.back())};
  };

  auto input_shape = std::vector<size_t>{m, static_cast<size_t>(input_dims.back())};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  // NOTE(review): the column-wise output uses the same {m, n} shape as the row-wise output,
  // whereas GetDActDBiasQuantizeWorkspaceSizes uses {hidden_size, batch_size} -- confirm.
  auto output_trans_shape = std::vector<size_t>{m, n};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(scaling_mode);
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (is_fp8_output) {
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
        scale_inv_shape(scale_inv_buf->dimensions()));

    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
      float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
      // Reset amax before the kernel runs; fail fast if the memset cannot be enqueued.
      NVTE_CHECK(cudaMemsetAsync(amax_out, 0, sizeof(float), stream) == cudaSuccess,
                 "Failed to enqueue cudaMemsetAsync for amax");
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
    }
  }

  if (is_2x) {
    output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);

    if (is_fp8_output) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &colwise_scale_inv_buf =
          (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
      output_tensor.set_columnwise_scale_inv(
          colwise_scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
          scale_inv_shape(colwise_scale_inv_buf->dimensions()));
    }
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  if (is_dbias) {
    // Fused dactivation + dbias + quantize; only non-gated activations are handled here.
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, " with dbias = True");
        break;
    }
  } else {
    // Plain dactivation (gated and non-gated); dbias and workspace are not passed.
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
  }

  return ffi_with_cuda_error_check();
}

351
// Registers DActLuDBiasQuantizeFFI as the XLA FFI handler symbol `DActLuDBiasQuantizeHandler`.
// The binding order below must match the parameter order of DActLuDBiasQuantizeFFI:
// context (stream), then input buffers, then result buffers, then attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<int64_t>("scaling_mode")  // -> scaling_mode_enum
                                  .Attr<bool>("is_2x")
                                  .Attr<bool>("is_dbias")
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
369
370
}  // namespace jax
}  // namespace transformer_engine