/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>

#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>

#include "c10/util/ArrayRef.h"
#include "common/util/logging.h"

namespace transformer_engine::pytorch {

// The Python-side equivalent is torch.distributed.ProcessGroup.
using dist_group_type = c10d::ProcessGroup;

// Each tensor here has shape (N,) and holds all of the scaling
// data for a single FP8 block, e.g. LayerNormLinear.
class FP8TensorMeta {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax_history;
};

// Used as named indices into the `scale`, `scale_inv`,
// and `amax_history` tensors of the `FP8TensorMeta` class.
enum FP8FwdTensors {
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8
};

// Used as named indices into the `scale`, `scale_inv`,
// and `amax_history` tensors of the `FP8TensorMeta` class.
enum FP8BwdTensors {
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5
};
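
// For example, `meta.scale[FP8FwdTensors::GEMM1_WEIGHT]` (where `meta` is a
// forward-pass FP8TensorMeta) is the quantization scale of the first GEMM's
// weight.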

class Quantizer {
 public:
  virtual NVTEScalingMode get_scaling_mode() const = 0;

  virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

  /*! @brief Construct a tensor with uninitialized data */
  virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                             DType dtype) const = 0;

  /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
   *
   * The PyTorch tensor's attributes are modified to match the
   * quantizer's configuration.
   */
  virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
      py::object tensor) const = 0;

  /*! @brief Convert to a quantized data format */
  virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
                        const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;

  virtual ~Quantizer() = default;

  bool rowwise_usage = true;
  bool columnwise_usage = true;
  bool internal = false;
  py::handle quantizer;

 protected:
  explicit Quantizer(const py::handle& quantizer);
};

class NoneQuantizer : public Quantizer {
 public:
  explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {}

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override {}

  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct a tensor with pre-initialized data */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                     at::Tensor data) const;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8Quantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;

  explicit Float8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct a tensor with pre-initialized data */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                     std::optional<at::Tensor> data,
                                                     std::optional<at::Tensor> transpose,
                                                     std::optional<at::Tensor> scale_inv) const;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8CurrentScalingQuantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;
  bool with_amax_reduction;
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0;

  explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8BlockQuantizer : public Quantizer {
 public:
  // Which FP8 type is used for the quantized data.
  DType dtype;
  // Options controlling how the tensor is quantized:
  // If true, quantization scales are rounded down to powers of 2.
  bool force_pow_2_scales = false;
  // The amax within each quantization tile is floored at this epsilon.
  float amax_epsilon = 0.0;
  // Whether the quantized tensor will be used in an all-gather.
  bool all_gather_usage = false;

 private:
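  // Dimensionality of the scaling blocks: 2 selects 2D block scaling,
  // 1 selects 1D block scaling (see get_scaling_mode below).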
  int block_scaling_dim = 2;

 public:
  // Initializes from a Python handle to a Float8BlockQuantizer.
  explicit Float8BlockQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override {
    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
  }

  // Gets rowwise and columnwise data from the tensor and sets them on the wrapper.
  void set_quantization_params(TensorWrapper* tensor) const override;

  // Create a Python Float8-block-quantized tensor and a C++ wrapper
  // for it. Sets quantized data and scales for rowwise and
  // optionally columnwise usage.
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};

class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;

  explicit MXFP8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};

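// Construct the C++ Quantizer corresponding to a Python quantizer object.
//
// A minimal usage sketch (`quantizer_py` and `input` are hypothetical names
// for a Python quantizer handle and an existing TensorWrapper):
//
//   std::unique_ptr<Quantizer> q = convert_quantizer(quantizer_py);
//   auto [out, out_py] = q->create_tensor({1024, 1024}, DType::kFloat8E4M3);
//   q->quantize(input, out);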
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

std::vector<size_t> getTensorShape(const at::Tensor& t);

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string& fp8_recipe);

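// Number of bits used to store one element of the given dtype
// (e.g. 16 for kBFloat16, 4 for kFloat4E2M1).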
inline size_t typeToNumBits(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt64:
      return 64;
    case transformer_engine::DType::kInt32:
    case transformer_engine::DType::kFloat32:
      return 32;
    case transformer_engine::DType::kInt16:
    case transformer_engine::DType::kFloat16:
    case transformer_engine::DType::kBFloat16:
      return 16;
    case transformer_engine::DType::kByte:
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
      return 8;
    case transformer_engine::DType::kFloat4E2M1:
      return 4;
    default:
      NVTE_ERROR("Invalid type");
  }
}

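// Map a Transformer Engine dtype to the corresponding ATen scalar type.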
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt16:
      return torch::kInt16;
    case transformer_engine::DType::kInt32:
      return torch::kInt32;
    case transformer_engine::DType::kInt64:
      return torch::kInt64;
    case transformer_engine::DType::kFloat32:
      return at::kFloat;
    case transformer_engine::DType::kFloat16:
      return at::kHalf;
    case transformer_engine::DType::kBFloat16:
      return at::kBFloat16;
    case transformer_engine::DType::kByte:
      return at::kByte;
    case transformer_engine::DType::kFloat8E4M3:
      return at::kFloat8_e4m3fn;
    case transformer_engine::DType::kFloat8E5M2:
      return at::kFloat8_e5m2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

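// Map an ATen scalar type to the corresponding Transformer Engine dtype
// (note that kBool and kByte both map to kByte).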
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kFloat8_e4m3fn:
      return transformer_engine::DType::kFloat8E4M3;
    case at::kFloat8_e5m2:
      return transformer_engine::DType::kFloat8E5M2;
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      return transformer_engine::DType::kByte;
    case torch::kByte:
      return transformer_engine::DType::kByte;
    case torch::kInt16:
      return transformer_engine::DType::kInt16;
    case torch::kInt32:
      return transformer_engine::DType::kInt32;
    case torch::kInt64:
      return transformer_engine::DType::kInt64;
    default:
      NVTE_ERROR("Invalid type: " + std::to_string(static_cast<int>(t)));
  }
}

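// Reinterpret an integer value (e.g. one passed from Python) as a
// Transformer Engine dtype.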
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  return static_cast<transformer_engine::DType>(DType_value);
}

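// Construct a TE tensor view over pre-allocated memory, without any
// quantization metadata.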
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
                                                              const transformer_engine::DType type);

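// Construct a TE tensor view with FP8 scaling metadata; the second overload
// additionally carries columnwise (transposed) data and scales.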
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape = {1},
    const std::vector<size_t>& columnwise_scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
                                                              const transformer_engine::DType type);

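// Wrap a PyTorch tensor as a TE tensor, inferring shape and dtype.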
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

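// Convert lists of PyTorch tensors into TE tensor wrappers, along with the
// raw NVTETensor handles and pointer arrays used for multi-tensor launches.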
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
           std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);

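// Wrap a Python tensor object as a TE tensor, using the given quantizer to
// interpret its quantization state.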
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv,
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

template <typename T>
T product(const std::vector<T>& shape);

size_t product(const NVTEShape& shape, size_t begin, size_t end);

std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape);

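// Allocate a CUDA tensor with the given shape and TE dtype, optionally
// zero-initialized.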
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
                         bool init_to_zeros);

at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
                         bool init_to_zeros = false);

at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);

at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);

void* getDataPtr(at::Tensor tensor, int offset = 0);

std::vector<size_t> convertShape(const NVTEShape& shape);

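// Round `value` up to the nearest multiple of `multiple`.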
int roundup(const int value, const int multiple);

NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
}  // namespace transformer_engine::pytorch

namespace std {
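// Vector -> string, e.g. "[1,2,3]"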
template <typename T>
string to_string(const vector<T>& vec) {
  string ret = "[";
  for (const auto& val : vec) {
    ret += to_string(val) + ",";
  }
  if (ret.size() > 1) {
    ret[ret.size() - 1] = ']';
  } else {
    ret += "]";
  }
  return ret;
}

// Torch shape -> string
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
  string ret = "[";
  for (const auto& val : vec) {
    ret += to_string(val) + ",";
  }
  if (ret.size() > 1) {
    ret[ret.size() - 1] = ']';
  } else {
    ret += "]";
  }
  return ret;
}

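// NVTE shape -> string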
inline string to_string(const NVTEShape& s) {
  string ret = "[";
  for (size_t i = 0; i < s.ndim; ++i) {
    ret += to_string(s.data[i]) + ",";
  }
  if (ret.size() > 1) {
    ret[ret.size() - 1] = ']';
  } else {
    ret += "]";
  }
  return ret;
}
}  // namespace std

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_