"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "d64bf1646c19123bf236d5ba1c02b75ee3d68306"
activation.cpp 18.8 KB
Newer Older
1
2
3
4
5
6
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "transformer_engine/activation.h"
7

8
#include "extensions.h"
9
#include "transformer_engine/transpose.h"
10
#include "xla/ffi/api/c_api.h"
11
12
13
14

namespace transformer_engine {
namespace jax {

15
// TODO: We won't need this function anymore when we move to the new XLA custom calls
16
17
// Returns the activation "length" for the given activation type: 1 for GELU, SILU, RELU, QGELU
// and SRELU, and 2 for the *GLU variants (GEGLU, SWIGLU, REGLU, QGEGLU, SREGLU). Callers in this
// file use the result as a multiplier on the trailing dimension of the activation input.
// Any other enum value is reported through NVTE_ERROR.
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
  switch (activation_enum) {
    case NVTE_Activation_Type::GELU:
    case NVTE_Activation_Type::SILU:
    case NVTE_Activation_Type::RELU:
    case NVTE_Activation_Type::QGELU:
    case NVTE_Activation_Type::SRELU:
      return 1;
    case NVTE_Activation_Type::GEGLU:
    case NVTE_Activation_Type::SWIGLU:
    case NVTE_Activation_Type::REGLU:
    case NVTE_Activation_Type::QGEGLU:
    case NVTE_Activation_Type::SREGLU:
      return 2;
    default:
      // The previous `break; return -1;` tail was dead code (the return sat after the break),
      // and -1 is not a meaningful size_t anyway, so the error report is the only exit here.
      NVTE_ERROR("Unsupported ActivationEnum");
  }
}

void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
46
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
47
               NVTE_Activation_Type act_enum, size_t act_len) {
48
49
50
51
52
53
  auto input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                     scale, scale_inverse);
  switch (act_enum) {
54
    case NVTE_Activation_Type::GELU:
55
56
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
57
    case NVTE_Activation_Type::GEGLU:
58
59
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
60
    case NVTE_Activation_Type::SILU:
61
62
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
63
    case NVTE_Activation_Type::SWIGLU:
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
88
89
90
}

// Legacy custom-call entry point for the forward activation without scaling metadata.
// buffers[0] is the input and buffers[1] the output; the shape, dtypes and activation enum are
// read from the CustomCallCommonDescriptor packed in `opaque`. The scale, scale_inverse and amax
// pointers are all passed to ActLuImpl as nullptr.
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *src = buffers[0];
  void *dst = buffers[1];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  const auto act_type = static_cast<NVTE_Activation_Type>(desc.act_enum);
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  const auto len = get_activation_len(act_type);

  ActLuImpl(src, rows, cols, desc.in_dtype, desc.out_dtype, /*scale=*/nullptr, stream,
            /*scale_inverse=*/nullptr, /*amax=*/nullptr, dst, act_type, len);
}

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// XLA FFI entry point for the forward activation without scaling metadata.
// The input dimensions are interpreted as (..., act_len, n): all leading dimensions are folded
// into the row count m, the second-to-last dimension is taken as act_len and the last one as n.
// The dtypes come from the buffers' element types; scale, scale_inverse and amax are passed to
// ActLuImpl as nullptr. Returns the result of ffi_with_cuda_error_check().
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
                    int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  // `end() - 2` and `end()[-2]` below are only valid when there are at least two dimensions.
  NVTE_CHECK(input_dims.size() >= 2, "ActLuFFI expects an input with at least 2 dimensions.");

  // Seed the product with a size_t. std::accumulate keeps its running value in the type of the
  // initial value, so the former literal `1` truncated the product to int on every step.
  const size_t m = std::accumulate(input_dims.begin(), input_dims.end() - 2,
                                   static_cast<size_t>(1), std::multiplies<>());
  const auto n = static_cast<size_t>(input_dims.back());
  const auto act_len = static_cast<size_t>(input_dims.end()[-2]);
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type,
            act_len);

  return ffi_with_cuda_error_check();
}

// Defines the XLA FFI handler symbol `ActLuHandler` backed by ActLuFFI. The binding order below
// follows ActLuFFI's parameter list: stream context, input buffer, output (result) buffer, and
// finally the int64 attribute named "act_enum".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"));

131
// Legacy custom-call entry point for the forward activation with scaling metadata.
// Buffer layout: [0] input, [1] amax, [2] scale, [3] scale_inv, [4] output, [5] amax_out.
// The amax and amax_out buffers must be the same pointer. When use_fp8(desc.out_dtype) is false,
// the scale, scale_inv and amax_out pointers are replaced by nullptr before calling ActLuImpl.
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *src = buffers[0];
  auto *amax = static_cast<float *>(buffers[1]);
  auto *scale = static_cast<float *>(buffers[2]);
  auto *scale_inv = static_cast<float *>(buffers[3]);
  void *dst = buffers[4];
  auto *amax_out = static_cast<float *>(buffers[5]);
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    // Scaling metadata is dropped when the output dtype is not an FP8 one.
    scale = scale_inv = amax_out = nullptr;
  }

  const auto act_type = static_cast<NVTE_Activation_Type>(desc.act_enum);
  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];
  const auto len = get_activation_len(act_type);

  ActLuImpl(src, rows, cols, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
            dst, act_type, len);
}

// Legacy custom-call entry point for the activation backward pass.
// Buffer layout: [0] input of shape (m, n), [1] act_input of shape (m, n * act_len),
// [2] output of shape (m, n * act_len), where m and n come from the packed descriptor and
// act_len from get_activation_len(). Both inputs are wrapped with desc.in_dtype and the output
// with desc.out_dtype. Unknown enum values are reported through NVTE_ERROR.
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *grad_ptr = buffers[0];
  void *act_in_ptr = buffers[1];
  void *out_ptr = buffers[2];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  const auto act_type = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto act_len = get_activation_len(act_type);

  std::vector<size_t> grad_dims{m, n};
  std::vector<size_t> act_in_dims{m, n * act_len};
  std::vector<size_t> out_dims{m, n * act_len};

  TensorWrapper grad(grad_ptr, grad_dims, desc.in_dtype);
  TensorWrapper act_in(act_in_ptr, act_in_dims, desc.in_dtype);
  TensorWrapper out(out_ptr, out_dims, desc.out_dtype);

  // All backward routines are called with the same four arguments.
  const auto launch = [&](auto kernel) {
    kernel(grad.data(), act_in.data(), out.data(), stream);
  };

  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      launch(nvte_dgelu);
      break;
    case NVTE_Activation_Type::GEGLU:
      launch(nvte_dgeglu);
      break;
    case NVTE_Activation_Type::SILU:
      launch(nvte_dsilu);
      break;
    case NVTE_Activation_Type::SWIGLU:
      launch(nvte_dswiglu);
      break;
    case NVTE_Activation_Type::RELU:
      launch(nvte_drelu);
      break;
    case NVTE_Activation_Type::REGLU:
      launch(nvte_dreglu);
      break;
    case NVTE_Activation_Type::QGELU:
      launch(nvte_dqgelu);
      break;
    case NVTE_Activation_Type::QGEGLU:
      launch(nvte_dqgeglu);
      break;
    case NVTE_Activation_Type::SRELU:
      launch(nvte_dsrelu);
      break;
    case NVTE_Activation_Type::SREGLU:
      launch(nvte_dsreglu);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
// XLA FFI entry point for the activation backward pass.
// The act_input dimensions are interpreted as (..., act_len, n): all leading dimensions are
// folded into the row count m, the second-to-last dimension is act_len and the last one is n.
// `input` is wrapped as (m, n); `act_input` and `output` as (m, n * act_len). Both inputs use the
// dtype of input_buf, the output uses the dtype of output_buf. Unknown enum values are reported
// through NVTE_ERROR. Returns the result of ffi_with_cuda_error_check().
Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf,
                     Result_Type output_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto act_input_dims = act_input_buf.dimensions();
  // `end() - 2` and `end()[-2]` below are only valid when there are at least two dimensions.
  NVTE_CHECK(act_input_dims.size() >= 2,
             "DActLuFFI expects act_input with at least 2 dimensions.");

  // Seed the product with a size_t. std::accumulate keeps its running value in the type of the
  // initial value, so the former literal `1` truncated the product to int on every step.
  const size_t m = std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2,
                                   static_cast<size_t>(1), std::multiplies<>());
  // Convert explicitly so the brace-initialised shape vectors below involve no implicit
  // signed-to-unsigned narrowing.
  const auto n = static_cast<size_t>(act_input_dims.back());
  const auto act_len = static_cast<size_t>(act_input_dims.end()[-2]);

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype);

  // All backward routines are called with the same four arguments.
  const auto launch = [&](auto kernel) {
    kernel(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
  };

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      launch(nvte_dgelu);
      break;
    case NVTE_Activation_Type::GEGLU:
      launch(nvte_dgeglu);
      break;
    case NVTE_Activation_Type::SILU:
      launch(nvte_dsilu);
      break;
    case NVTE_Activation_Type::SWIGLU:
      launch(nvte_dswiglu);
      break;
    case NVTE_Activation_Type::RELU:
      launch(nvte_drelu);
      break;
    case NVTE_Activation_Type::REGLU:
      launch(nvte_dreglu);
      break;
    case NVTE_Activation_Type::QGELU:
      launch(nvte_dqgelu);
      break;
    case NVTE_Activation_Type::QGEGLU:
      launch(nvte_dqgeglu);
      break;
    case NVTE_Activation_Type::SRELU:
      launch(nvte_dsrelu);
      break;
    case NVTE_Activation_Type::SREGLU:
      launch(nvte_dsreglu);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

// Defines the XLA FFI handler symbol `DActLuHandler` backed by DActLuFFI. The binding order below
// follows DActLuFFI's parameter list: stream context, the two input buffers (input, act_input),
// the output (result) buffer, and finally the int64 attribute named "act_enum".
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"));

281
// Reports the workspace shape and dtype for the fused dact + dbias + cast-transpose routines.
// Builds tensor wrappers with null data pointers for the given sizes and dtypes, calls
// nvte_cast_transpose_dbias_dgelu with a default-constructed workspace tensor and a null stream,
// and returns a 1-tuple holding (workspace shape, workspace dtype) as read back from that
// workspace tensor afterwards. The call is presumably a size query that only fills in the
// workspace metadata -- confirm against the TE API documentation.
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype) {
  // input, dact_input and output all share the (batch, hidden) shape.
  std::vector<size_t> dense_dims{batch_size, hidden_size};
  std::vector<size_t> transposed_dims{hidden_size, batch_size};
  std::vector<size_t> bias_dims{hidden_size};

  TensorWrapper grad(nullptr, dense_dims, in_dtype);
  TensorWrapper dact_in(nullptr, dense_dims, in_dtype);
  TensorWrapper cast_out(nullptr, dense_dims, out_dtype);
  TensorWrapper cast_out_t(nullptr, transposed_dims, out_dtype);
  TensorWrapper dbias(nullptr, bias_dims, in_dtype);

  TensorWrapper query_workspace;

  // For now, all dbias_dact(-s) have the same workspace size
  nvte_cast_transpose_dbias_dgelu(grad.data(), dact_in.data(), cast_out.data(), cast_out_t.data(),
                                  dbias.data(), query_workspace.data(), nullptr);

  auto wk_dims = MakeShapeVector(query_workspace.shape());
  return pybind11::make_tuple(std::make_pair(wk_dims, query_workspace.dtype()));
}

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
307
308
309
310
311
312
313
314
315
316
317
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  auto *dbias = buffers[7];
  float *amax_out = reinterpret_cast<float *>(buffers[8]);
  void *workspace_ptr = buffers[9];
318

319
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
320
321
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
322
323
324
325
326
327
328
329
330
331
332
333
334
335
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ;
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
336

337
338
339
340
341
342
343
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
344

345
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
346

347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                      output_tensor.data(), output_trans_tensor.data(),
                                      dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                       output_tensor.data(), output_trans_tensor.data(),
                                       dbias_tensor.data(), workspace.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
377
378
379
}

void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
380
381
382
383
384
385
386
387
388
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
389

390
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
391
392
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
393
394
395
396
397
398
399
400
401
402
403
404
405
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  ;
  auto input_shape = desc.shape.to_vector();
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};
406

407
408
409
410
411
412
  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
413

414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
  switch (act_enum) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
                                 output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), output_trans_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
439
440
441
442
}

}  // namespace jax
}  // namespace transformer_engine