activation.cpp 14.3 KB
Newer Older
1
2
3
4
5
6
7
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/activation.h"
8
9

#include "jax/csrc/extensions.h"
10
11
12
13
14
15
16
#include "transformer_engine/transpose.h"

namespace transformer_engine {
namespace jax {

// Returns how many n-wide column groups an activation consumes per output row:
// 1 for the element-wise activations (GELU, SILU, RELU, QGELU, SRELU) and 2 for the
// gated variants (GEGLU, SWIGLU, REGLU, QGEGLU, SREGLU). Callers in this file use it to
// size the activation input as {m, n * act_len}.
// Unsupported enum values are reported through NVTE_ERROR.
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
  switch (activation_enum) {
    case NVTE_Activation_Type::GELU:
    case NVTE_Activation_Type::SILU:
    case NVTE_Activation_Type::RELU:
    case NVTE_Activation_Type::QGELU:
    case NVTE_Activation_Type::SRELU:
      return 1;
    case NVTE_Activation_Type::GEGLU:
    case NVTE_Activation_Type::SWIGLU:
    case NVTE_Activation_Type::REGLU:
    case NVTE_Activation_Type::QGEGLU:
    case NVTE_Activation_Type::SREGLU:
      return 2;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  // Reached only if NVTE_ERROR does not abort control flow. Previously a `return -1;` was
  // placed after `break;` (dead code), letting control fall off the end of this non-void
  // function. Returning explicitly keeps every path well-defined; 0 yields an empty shape
  // rather than a wrapped-around SIZE_MAX column count.
  return 0;
}

void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
45
46
47
48
49
50
51
52
53
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
               NVTE_Activation_Type act_enum) {
  auto act_len = get_activation_len(act_enum);
  auto input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                     scale, scale_inverse);
  switch (act_enum) {
54
    case NVTE_Activation_Type::GELU:
55
56
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
57
    case NVTE_Activation_Type::GEGLU:
58
59
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
60
    case NVTE_Activation_Type::SILU:
61
62
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
63
    case NVTE_Activation_Type::SWIGLU:
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
88
89
90
}

// Custom-call entry point for the non-FP8 forward activation.
//
// buffers: [0] input, [1] output. Shape, dtypes and activation type come from the
// CustomCallCommonDescriptor packed in `opaque`. No amax/scale buffers are bound, so
// ActLuImpl receives nullptr for all FP8 metadata.
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *in_buf = buffers[0];
  void *out_buf = buffers[1];

  const auto &descriptor = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto rows = descriptor.shape.dims[0];
  const auto cols = descriptor.shape.dims[1];
  const auto activation = static_cast<NVTE_Activation_Type>(descriptor.act_enum);

  ActLuImpl(in_buf, rows, cols, descriptor.in_dtype, descriptor.out_dtype, /*scale=*/nullptr,
            stream, /*scale_inverse=*/nullptr, /*amax=*/nullptr, out_buf, activation);
}

// Custom-call entry point for the forward activation with optional FP8 output.
//
// buffers: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output, [5] amax_out.
// amax_out must be the very same buffer as amax (enforced by NVTE_CHECK). When
// out_dtype is not an FP8 type (per use_fp8), scale, scale_inv and amax_out are replaced
// with nullptr before being attached to the output tensor.
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *in_buf = buffers[0];
  auto *amax_in = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  void *out_buf = buffers[4];
  auto *amax_out = reinterpret_cast<float *>(buffers[5]);
  NVTE_CHECK(amax_in == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");

  const auto &descriptor = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

  // Non-FP8 outputs carry no scaling metadata.
  const bool fp8_out = use_fp8(descriptor.out_dtype);
  float *scale_arg = fp8_out ? scale : nullptr;
  float *scale_inv_arg = fp8_out ? scale_inv : nullptr;
  float *amax_arg = fp8_out ? amax_out : nullptr;

  const auto rows = descriptor.shape.dims[0];
  const auto cols = descriptor.shape.dims[1];
  const auto activation = static_cast<NVTE_Activation_Type>(descriptor.act_enum);

  ActLuImpl(in_buf, rows, cols, descriptor.in_dtype, descriptor.out_dtype, scale_arg, stream,
            scale_inv_arg, amax_arg, out_buf, activation);
}

// Custom-call entry point for the activation backward pass (no FP8 metadata).
//
// buffers: [0] input, viewed as {m, n} of in_dtype (presumably the incoming gradient);
//          [1] act_input, viewed as {m, n * act_len} of in_dtype;
//          [2] output, viewed as {m, n * act_len} of out_dtype.
// act_len comes from get_activation_len (2 for gated activations, 1 otherwise).
// Unknown activation types are reported through NVTE_ERROR.
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *grad_buf = buffers[0];
  void *act_in_buf = buffers[1];
  void *out_buf = buffers[2];

  const auto &descriptor = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto rows = descriptor.shape.dims[0];
  const auto cols = descriptor.shape.dims[1];
  const auto activation = static_cast<NVTE_Activation_Type>(descriptor.act_enum);
  const auto wide_cols = cols * get_activation_len(activation);

  TensorWrapper grad_tw(grad_buf, std::vector<size_t>{rows, cols}, descriptor.in_dtype);
  TensorWrapper act_in_tw(act_in_buf, std::vector<size_t>{rows, wide_cols},
                          descriptor.in_dtype);
  TensorWrapper out_tw(out_buf, std::vector<size_t>{rows, wide_cols}, descriptor.out_dtype);
  auto grad_h = grad_tw.data();
  auto act_in_h = act_in_tw.data();
  auto out_h = out_tw.data();

  switch (activation) {
    // Element-wise activations.
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(grad_h, act_in_h, out_h, stream);
      break;
    // Gated activations.
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(grad_h, act_in_h, out_h, stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(grad_h, act_in_h, out_h, stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}

// Returns a one-element tuple holding (workspace shape, workspace dtype) for a
// batch_size x hidden_size dact+dbias+cast+transpose problem. The result is presumably
// used to size the workspace buffer consumed by DActLuDBiasCastTranspose.
//
// Every tensor is built with a null data pointer, and the workspace tensor starts out
// default-constructed. After the nvte call its shape/dtype are read back; this assumes the
// nvte routine fills them in as a size query when given an empty workspace (confirm
// against the TE API). For now all dbias_dact kernels share the same workspace size, so
// the GELU variant stands in for every activation.
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
  const std::vector<size_t> bh_shape{batch_size, hidden_size};
  const std::vector<size_t> hb_shape{hidden_size, batch_size};
  const std::vector<size_t> bias_shape{hidden_size};

  TensorWrapper grad_tw(nullptr, bh_shape, in_dtype);
  TensorWrapper act_in_tw(nullptr, bh_shape, in_dtype);
  TensorWrapper out_tw(nullptr, bh_shape, out_dtype);
  TensorWrapper out_t_tw(nullptr, hb_shape, out_dtype);
  TensorWrapper dbias_tw(nullptr, bias_shape, in_dtype);
  TensorWrapper workspace_query;

  nvte_cast_transpose_dbias_dgelu(grad_tw.data(), act_in_tw.data(), out_tw.data(),
                                  out_t_tw.data(), dbias_tw.data(), workspace_query.data(),
                                  /*stream=*/nullptr);

  const auto ws_shape = MakeShapeVector(workspace_query.shape());
  return pybind11::make_tuple(std::make_pair(ws_shape, workspace_query.dtype()));
}

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
211
212
213
214
215
216
217
218
219
220
221
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  auto *dbias = buffers[7];
  float *amax_out = reinterpret_cast<float *>(buffers[8]);
  void *workspace_ptr = buffers[9];
222

223
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
224
225
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
226
227
228
229
230
231
232
233
234
235
236
237
238
239
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ;
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
240

241
242
243
244
245
246
247
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
248

249
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
250

251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
281
282
283
}

void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
284
285
286
287
288
289
290
291
292
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
293

294
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
295
296
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
297
298
299
300
301
302
303
304
305
306
307
308
309
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ;
  auto input_shape = desc.shape.to_vector();
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
310

311
312
313
314
315
316
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
317

318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
  switch (act_enum) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
343
344
345
346
}

}  // namespace jax
}  // namespace transformer_engine