/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
cyanguwa's avatar
cyanguwa committed
11
#include <ATen/Dispatch.h>
Tim Moon's avatar
Tim Moon committed
12
#include <ATen/cuda/CUDAContext.h>
cyanguwa's avatar
cyanguwa committed
13
#include <ATen/cuda/CUDAGeneratorImpl.h>
Tim Moon's avatar
Tim Moon committed
14
15
16
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
17
18
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
yuguo's avatar
yuguo committed
19
20
#include <cuda_runtime.h>
#ifndef USE_ROCM
Tim Moon's avatar
Tim Moon committed
21
#include <cublasLt.h>
Przemek Tredak's avatar
Przemek Tredak committed
22
#include <cuda.h>
23
#include <cudnn.h>
yuguo's avatar
yuguo committed
24
25
26
27
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
Tim Moon's avatar
Tim Moon committed
28
29
30
31
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
32
#include <transformer_engine/cast_transpose_noop.h>
33
#include <transformer_engine/comm_gemm_overlap.h>
Tim Moon's avatar
Tim Moon committed
34
#include <transformer_engine/fused_attn.h>
35
#include <transformer_engine/fused_rope.h>
36
#include <transformer_engine/fused_router.h>
Tim Moon's avatar
Tim Moon committed
37
#include <transformer_engine/gemm.h>
38
#include <transformer_engine/multi_stream.h>
39
#include <transformer_engine/multi_tensor.h>
40
#include <transformer_engine/normalization.h>
41
#include <transformer_engine/padding.h>
42
#include <transformer_engine/permutation.h>
43
#include <transformer_engine/recipe.h>
Tim Moon's avatar
Tim Moon committed
44
#include <transformer_engine/softmax.h>
45
#include <transformer_engine/swizzle.h>
Tim Moon's avatar
Tim Moon committed
46
47
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
48
49

#include <ATen/cuda/CUDAGraphsUtils.cuh>
50
#include <cassert>
51
52
#include <cstring>
#include <iostream>
53
54
#include <string>
#include <sstream>
55
#include <memory>
56
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
57
58
#include <vector>

59
#include "c10/util/ArrayRef.h"
60
#include "common/util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
61

62
namespace transformer_engine::pytorch {
Przemek Tredak's avatar
Przemek Tredak committed
63

64
65
66
// in python we have: dist_group_type = torch.distributed.ProcessGroup
// This C++ alias mirrors that Python alias so that both sides name the same
// c10d process-group type.
using dist_group_type = c10d::ProcessGroup;

67
68
69
70
71
72
73
74
75
76
77
78
/*! @brief Block length configured through NVTE_BLOCKWISE_FP8_BLOCK_LEN.
 *
 * Returns 128 when the environment variable is unset or empty.
 *
 * Otherwise the value must be a positive decimal integer, optionally surrounded by
 * whitespace. The following are all rejected with an error instead of being silently
 * truncated or accepted:
 *  - non-numeric text;
 *  - trailing characters (e.g. "128abc");
 *  - values that overflow `int`;
 *  - zero and negative numbers.
 */
inline int blockwise_fp8_block_len() {
  const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
  if (env == nullptr || env[0] == '\0') {
    return 128;
  }
  int value = 0;
  std::istringstream iss(env);
  iss >> value;
  // operator>> sets failbit on non-numeric input and on int overflow.
  const bool parsed = !iss.fail();
  // Skip trailing whitespace; anything left after that is junk. eof() is checked
  // (not the whole stream state) because std::ws may set failbit when eofbit was
  // already set by the integer extraction.
  iss >> std::ws;
  const bool fully_consumed = iss.eof();
  NVTE_CHECK(parsed && fully_consumed && value > 0,
             "Invalid environment variable value for NVTE_BLOCKWISE_FP8_BLOCK_LEN: \"", env,
             "\" (expected a positive integer)");
  return value;
}

Przemek Tredak's avatar
Przemek Tredak committed
79
80
81
82
// FP8 scaling state for one FP8 block (e.g. LayerNormLinear).
// Every member is a tensor of shape (N, ) holding the scaling data of all FP8
// tensors in that block. Individual slots are addressed with the FP8FwdTensors /
// FP8BwdTensors indices declared below.
class FP8TensorMeta {
 public:
  at::Tensor scale;         // per-slot scaling factors
  at::Tensor scale_inv;     // per-slot inverse scales (presumably 1 / scale -- confirm)
  at::Tensor amax_history;  // per-slot record of observed amax values
};

// Forward-pass slot indices into the `scale`, `scale_inv` and amax tensors of
// `FP8TensorMeta`. Each GEMM owns three consecutive slots: input, weight, output.
enum FP8FwdTensors {
  // GEMM 1
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  // GEMM 2
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  // GEMM 3
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8,
};

// Backward-pass slot indices into the `scale`, `scale_inv` and amax tensors of
// `FP8TensorMeta`. Each GEMM owns two consecutive slots: grad output, grad input.
enum FP8BwdTensors {
  // GEMM 1
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  // GEMM 2
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  // GEMM 3
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5,
};

113
114
115
116
117
118
class Quantizer {
 public:
  virtual NVTEScalingMode get_scaling_mode() const = 0;

  virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  /*! @brief Construct a tensor with uninitialized data */
  virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                             DType dtype) const = 0;

  /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
   *
   * The PyTorch tensor's attributes are modified to match the
   * quantizer's configuration.
   */
  virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
      py::object tensor) const = 0;

  /*! @brief Convert to a quantized data format */
  virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
                        const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

  virtual ~Quantizer() = default;

  bool rowwise_usage = true;
  bool columnwise_usage = true;
  bool internal = false;
  py::handle quantizer;

 protected:
  explicit Quantizer(const py::handle& quantizer);
};

class NoneQuantizer : public Quantizer {
 public:
  explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {}

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override {}

154
155
156
157
158
159
160
161
162
163
164
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct a tensor with pre-initialized data */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                     at::Tensor data) const;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
};

/*! @brief FP8 quantizer using delayed tensor scaling.
 *
 * Holds the scale / scale_inv / amax tensors and the FP8 dtype used when producing
 * quantized tensors.
 */
class Float8Quantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  // Element type of the quantized data.
  DType dtype;

  explicit Float8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  /*! @brief Construct a tensor with uninitialized data */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct a tensor with pre-initialized data
   *
   * `data`, `transpose` and `scale_inv` are optional pre-allocated buffers.
   */
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                     std::optional<at::Tensor> data,
                                                     std::optional<at::Tensor> transpose,
                                                     std::optional<at::Tensor> scale_inv) const;

  // The argument is the PyTorch tensor to convert (previously misnamed `shape`).
  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8CurrentScalingQuantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;
  bool with_amax_reduction;
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0;

  explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

212
213
214
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

215
216
217
218
219
220
221
222
  /*! @brief Construct a high precision tensor giving it this quantizer's amax

  Note: this member function also zeros out the amax, as it is meant to be used in conjunction with
        a kernel computing the amax, which might expect the amax to be initialized to zero
  */
  std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape,
                                                                  DType dtype);

223
224
225
226
  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
227
228
229
230
231
232
233
234

  /*! @brief Convert to a quantized data format avoiding amax computation */
  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
                          const std::optional<TensorWrapper>& noop_flag = std::nullopt);

 private:
  void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                     const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
235
236
237
238
239
240
241
242
243
244
245
};

class Float8BlockQuantizer : public Quantizer {
 public:
  // Which float8 type is used for q data.
  DType dtype;
  // Options about how to quantize the tensor
  // Quantization scales are rounded down to powers of 2.
  bool force_pow_2_scales = false;
  // Amax within quantization tile has a floor of epsilon.
  float amax_epsilon = 0.0;
246
247
  // Whether quantized tensor will be used in an all-gather
  bool all_gather_usage = false;
248
249

 private:
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
  int block_scaling_dim = 2;

 public:
  // Initializes from a python handle to a Float8BlockQuantizer
  explicit Float8BlockQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override {
    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
  }

  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
  void set_quantization_params(TensorWrapper* tensor) const override;

  // Create a python Float8BlockQuantized tensor and C++ wrapper
  // for the tensor. Should set quantized data, scales for rowwise
  // and optionally columnwise usage.
266
267
268
269
270
271
272
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
273
274

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
275
276
277
278
279
280
281
282
283
284
285
286
};

class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;

  explicit MXFP8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

287
288
289
290
291
292
293
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
294
295

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
296
297
298
299
};

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

300
std::vector<size_t> getTensorShape(const at::Tensor& t);
Przemek Tredak's avatar
Przemek Tredak committed
301
302

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
303
                                                      const std::string& fp8_recipe);
Przemek Tredak's avatar
Przemek Tredak committed
304

305
/*! @brief Width in bits of one element of the Transformer Engine dtype `t`.
 *
 * Any dtype not listed below raises an error.
 */
inline size_t typeToNumBits(transformer_engine::DType t) {
  using TEType = transformer_engine::DType;
  switch (t) {
#if FP4_TYPE_SUPPORTED
    case TEType::kFloat4E2M1:
      return 4;
#endif
    case TEType::kByte:
    case TEType::kInt8:
    case TEType::kFloat8E4M3:
    case TEType::kFloat8E5M2:
      return 8;
    case TEType::kInt16:
    case TEType::kFloat16:
    case TEType::kBFloat16:
      return 16;
    case TEType::kInt32:
    case TEType::kFloat32:
      return 32;
    case TEType::kInt64:
      return 64;
    default:
      NVTE_ERROR("Invalid type");
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
330
/*! @brief PyTorch scalar type corresponding to the Transformer Engine dtype `t`.
 *
 * Any dtype not listed below raises an error.
 */
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  using TEType = transformer_engine::DType;
  switch (t) {
    // Raw bytes and integers.
    case TEType::kByte:
      return at::kByte;
    case TEType::kInt8:
      return torch::kInt8;
    case TEType::kInt16:
      return torch::kInt16;
    case TEType::kInt32:
      return torch::kInt32;
    case TEType::kInt64:
      return torch::kInt64;
    // Floating point.
    case TEType::kFloat32:
      return at::kFloat;
    case TEType::kFloat16:
      return at::kHalf;
    case TEType::kBFloat16:
      return at::kBFloat16;
    case TEType::kFloat8E4M3:
      return at::kFloat8_e4m3fn;
    case TEType::kFloat8E5M2:
      return at::kFloat8_e5m2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

/*! @brief Transformer Engine dtype corresponding to the PyTorch scalar type `t`.
 *
 * Both at::kBool and at::kByte map to DType::kByte. Any other unsupported scalar
 * type raises an error whose message names the offending type.
 */
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kFloat8_e4m3fn:
      return transformer_engine::DType::kFloat8E4M3;
    case at::kFloat8_e5m2:
      return transformer_engine::DType::kFloat8E5M2;
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      return transformer_engine::DType::kByte;
    case torch::kByte:
      return transformer_engine::DType::kByte;
    case torch::kInt16:
      return transformer_engine::DType::kInt16;
    case torch::kInt32:
      return transformer_engine::DType::kInt32;
    case torch::kInt64:
      return transformer_engine::DType::kInt64;
    case torch::kInt8:
      return transformer_engine::DType::kInt8;
    default:
      // Report the unsupported type inside the exception message rather than printing
      // it to stdout (with a flush) before throwing, so the diagnostic travels with the error.
      NVTE_ERROR("Invalid type: ", c10::toString(t), " (", static_cast<int>(t), ")");
  }
}

/*! @brief Reinterpret a raw integer as a Transformer Engine dtype.
 *
 * NOTE(review): no range check is performed here; the caller must supply a valid
 * enumerator value.
 */
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  using TEType = transformer_engine::DType;
  const TEType converted = static_cast<TEType>(DType_value);
  return converted;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
393
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
394

395
396
397
398
399
400
401
402
403
404
405
406
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape = {1},
    const std::vector<size_t>& columnwise_scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
Przemek Tredak's avatar
Przemek Tredak committed
407
408
409

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
410
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
411
412
413

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

414
415
416
417
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
           std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);

418
419
420
421
422
423
424
425
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv,
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

/*! @brief Product of every entry of `shape`. */
template <typename T>
T product(const std::vector<T>& shape);

/*! @brief Product of the dimensions of `shape` between indices `begin` and `end`.
 *
 * NOTE(review): presumably the half-open range [begin, end) -- confirm in the definition.
 */
size_t product(const NVTEShape& shape, size_t begin, size_t end);

/*! @brief Copy the dimensions of an NVTEShape into a std::vector. */
std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape);
Przemek Tredak's avatar
Przemek Tredak committed
430

431
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
cyanguwa's avatar
cyanguwa committed
432
                         bool init_to_zeros);
Przemek Tredak's avatar
Przemek Tredak committed
433

434
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
Przemek Tredak's avatar
Przemek Tredak committed
435
436
                         bool init_to_zeros = false);

437
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
438

439
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
440

441
void* getDataPtr(at::Tensor tensor, int offset = 0);
442

443
444
std::vector<size_t> convertShape(const NVTEShape& shape);

445
size_t roundup(const size_t value, const size_t multiple);
446

447
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
}  // namespace transformer_engine::pytorch

// NOTE(review): These to_string overloads live in namespace std so that callers can write
// std::to_string(shape). The standard only permits adding specializations of existing std
// templates, not new overloads, so strictly this is undefined behavior. Relocating them to a
// project namespace would require updating every call site, so it is left as is for now.
namespace std {
/*! @brief Format a vector as "[a,b,c]" ("[]" when empty), calling to_string on each element.
 *
 * The ArrayRef and NVTEShape overloads below delegate to this one, so the bracket and
 * comma format is defined in a single place.
 */
template <typename T>
string to_string(const vector<T>& vec) {
  string ret = "[";
  for (const auto& val : vec) {
    ret += to_string(val);
    ret += ',';
  }
  if (ret.size() > 1) {
    ret.back() = ']';  // Overwrite the trailing comma with the closing bracket.
  } else {
    ret += ']';
  }
  return ret;
}

// Torch shape -> string, in the same "[a,b,c]" format.
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
  return to_string(vector<T>(vec.begin(), vec.end()));
}

// NVTEShape -> string, in the same "[a,b,c]" format.
inline string to_string(const NVTEShape& s) {
  return to_string(vector<size_t>(s.data, s.data + s.ndim));
}
}  // namespace std

Przemek Tredak's avatar
Przemek Tredak committed
494
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_