/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
cyanguwa's avatar
cyanguwa committed
11
#include <ATen/Dispatch.h>
Tim Moon's avatar
Tim Moon committed
12
#include <ATen/cuda/CUDAContext.h>
cyanguwa's avatar
cyanguwa committed
13
#include <ATen/cuda/CUDAGeneratorImpl.h>
Tim Moon's avatar
Tim Moon committed
14
15
16
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
17
18
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
yuguo's avatar
yuguo committed
19
20
#include <cuda_runtime.h>
#ifndef USE_ROCM
Tim Moon's avatar
Tim Moon committed
21
#include <cublasLt.h>
Przemek Tredak's avatar
Przemek Tredak committed
22
#include <cuda.h>
23
#include <cudnn.h>
yuguo's avatar
yuguo committed
24
25
26
27
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
Tim Moon's avatar
Tim Moon committed
28
29
30
31
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
32
#include <transformer_engine/cast_transpose_noop.h>
33
#include <transformer_engine/comm_gemm_overlap.h>
Tim Moon's avatar
Tim Moon committed
34
#include <transformer_engine/fused_attn.h>
35
#include <transformer_engine/fused_rope.h>
Tim Moon's avatar
Tim Moon committed
36
#include <transformer_engine/gemm.h>
37
#include <transformer_engine/multi_tensor.h>
38
#include <transformer_engine/normalization.h>
39
#include <transformer_engine/padding.h>
40
#include <transformer_engine/permutation.h>
41
#include <transformer_engine/recipe.h>
Tim Moon's avatar
Tim Moon committed
42
#include <transformer_engine/softmax.h>
43
#include <transformer_engine/swizzle.h>
Tim Moon's avatar
Tim Moon committed
44
45
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
46
47

#include <ATen/cuda/CUDAGraphsUtils.cuh>
48
#include <cassert>
49
50
#include <cstring>
#include <iostream>
51
52
#include <string>
#include <sstream>
53
#include <memory>
54
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
55
56
#include <vector>

57
#include "c10/util/ArrayRef.h"
58
#include "common/util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
59

60
namespace transformer_engine::pytorch {
Przemek Tredak's avatar
Przemek Tredak committed
61

62
63
64
// C++ alias for the process-group type. On the Python side the matching
// alias is: dist_group_type = torch.distributed.ProcessGroup
using dist_group_type = c10d::ProcessGroup;

65
66
67
68
69
70
71
72
73
74
75
76
// Returns the block length configured through the
// NVTE_BLOCKWISE_FP8_BLOCK_LEN environment variable.
//
// - Unset or empty variable: returns the default of 128.
// - Otherwise the value is parsed as a decimal int. A value that cannot be
//   parsed, or that is not strictly positive, is rejected via NVTE_CHECK
//   (a zero or negative block length is meaningless).
//
// NOTE(review): presumably this is the tile length used for blockwise FP8
// quantization - confirm against the callers.
inline int blockwise_fp8_block_len() {
  constexpr int kDefaultBlockLen = 128;
  const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
  if (env == nullptr || env[0] == '\0') {
    return kDefaultBlockLen;
  }
  // Initialize so the variable never holds an indeterminate value, even if
  // the extraction below fails.
  int value = 0;
  std::istringstream iss(env);
  iss >> value;
  NVTE_CHECK(!iss.fail() && value > 0,
             "NVTE_BLOCKWISE_FP8_BLOCK_LEN must be a positive integer");
  return value;
}

Przemek Tredak's avatar
Przemek Tredak committed
77
78
79
80
// Bundle of FP8 scaling metadata tensors.
//
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear.
// Entries are addressed with the FP8FwdTensors / FP8BwdTensors enums.
class FP8TensorMeta {
 public:
  // Per-tensor scaling factors.
  at::Tensor scale;
  // NOTE(review): presumably the reciprocals of `scale` - confirm.
  at::Tensor scale_inv;
  // NOTE(review): presumably a history of absolute-max values used to
  // derive `scale` - confirm.
  at::Tensor amax_history;
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax_history` tensors in the `FP8TensorMeta` class
// (forward-pass tensors). Three slots (input, weight, output)
// are reserved for each of up to three GEMMs.
enum FP8FwdTensors {
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax_history` tensors in the `FP8TensorMeta` class
// (backward-pass tensors). Two slots (grad output, grad input)
// are reserved for each of up to three GEMMs.
enum FP8BwdTensors {
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5
};

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// Abstract base class for the C++ quantizer wrappers declared in this header.
// Each concrete quantizer is constructed from a Python handle and must
// implement the three pure-virtual methods below.
class Quantizer {
 public:
  // Scaling mode (NVTEScalingMode) associated with this quantizer.
  virtual NVTEScalingMode get_scaling_mode() const = 0;

  // Applies this quantizer's parameters to `tensor`.
  // NOTE(review): which fields are set depends on the subclass; the
  // definitions are not visible in this header.
  virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

  // Creates a tensor of the given shape/dtype and returns it both as a
  // C++ TensorWrapper and as a Python object.
  // NOTE(review): `rowwise_data` presumably lets the caller supply an
  // existing buffer for the row-wise data - confirm.
  virtual std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;

  // Virtual destructor so instances can be destroyed through a base pointer.
  virtual ~Quantizer() = default;

  // Usage flags; both default to true.
  bool rowwise_usage = true;
  bool columnwise_usage = true;
  // Defaults to false. NOTE(review): exact meaning not shown here.
  bool internal = false;
  // Handle to the Python-side quantizer object.
  // NOTE(review): assuming `py` aliases pybind11, py::handle is non-owning,
  // so the Python object must outlive this wrapper - confirm.
  py::handle quantizer;

 protected:
  // Only subclasses may construct; the definition is not in this header.
  explicit Quantizer(const py::handle& quantizer);
};

// Quantizer whose set_quantization_params() is a no-op.
// NOTE(review): presumably used when no quantization is requested - confirm.
class NoneQuantizer : public Quantizer {
 public:
  // Forwards the Python handle to the Quantizer base constructor.
  explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {}

  // Always reports NVTE_DELAYED_TENSOR_SCALING.
  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  // Intentionally empty: nothing is set on the tensor.
  void set_quantization_params(TensorWrapper* tensor) const override {}

  // Declared here, defined outside this header.
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

class Float8Quantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;

  explicit Float8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

class Float8CurrentScalingQuantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;
  bool with_amax_reduction;
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0;

  explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

180
181
182
183
184
185
186
187
188
189
190
191
192
193
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

class Float8BlockQuantizer : public Quantizer {
 public:
  // Which float8 type is used for q data.
  DType dtype;
  // Options about how to quantize the tensor
  // Quantization scales are rounded down to powers of 2.
  bool force_pow_2_scales = false;
  // Amax within quantization tile has a floor of epsilon.
  float amax_epsilon = 0.0;
194
195

 private:
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
  int block_scaling_dim = 2;

 public:
  // Initializes from a python handle to a Float8BlockQuantizer
  explicit Float8BlockQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override {
    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
  }

  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
  void set_quantization_params(TensorWrapper* tensor) const override;

  // Create a python Float8BlockQuantized tensor and C++ wrapper
  // for the tensor. Should set quantized data, scales for rowwise
  // and optionally columnwise usage.
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

// Quantizer that reports the NVTE_MXFP8_1D_SCALING scaling mode.
class MXFP8Quantizer : public Quantizer {
 public:
  // Data type associated with this quantizer (presumably the FP8 format).
  DType dtype;

  // Initializes from a Python handle; defined outside this header.
  explicit MXFP8Quantizer(const py::handle& quantizer);

  // Always reports NVTE_MXFP8_1D_SCALING.
  NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; }

  // Applies this quantizer's parameters to `tensor`; defined outside this header.
  void set_quantization_params(TensorWrapper* tensor) const override;

  // Creates a tensor of the given shape/dtype, returned as a C++ wrapper
  // plus the matching Python object; defined outside this header.
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

// Builds an owning C++ Quantizer from a Python quantizer handle.
// NOTE(review): presumably selects the concrete subclass from the Python
// object's type; the definition is not visible in this header.
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

// Returns the shape of `t` as a vector of size_t.
std::vector<size_t> getTensorShape(at::Tensor t);
Przemek Tredak's avatar
Przemek Tredak committed
235
236

// Returns the Transformer Engine FP8 dtype for the given recipe name.
// NOTE(review): presumably `e4m3_if_hybrid` chooses E4M3 over E5M2 when the
// recipe is a hybrid one; the definition is not visible here - confirm.
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string& fp8_recipe);
Przemek Tredak's avatar
Przemek Tredak committed
238

239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
// Size in bytes of one element of the given Transformer Engine dtype.
// Any dtype not listed below is reported through NVTE_ERROR("Invalid type").
inline size_t typeToSize(transformer_engine::DType t) {
  using TEType = transformer_engine::DType;
  switch (t) {
    // 1-byte element types
    case TEType::kByte:
    case TEType::kFloat8E4M3:
    case TEType::kFloat8E5M2:
      return 1;
    // 2-byte element types
    case TEType::kInt16:
    case TEType::kFloat16:
    case TEType::kBFloat16:
      return 2;
    // 4-byte element types
    case TEType::kInt32:
    case TEType::kFloat32:
      return 4;
    // 8-byte element types
    case TEType::kInt64:
      return 8;
    default:
      NVTE_ERROR("Invalid type");
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
259
// Converts a Transformer Engine dtype to the corresponding ATen scalar type.
// Any dtype not listed below is reported through NVTE_ERROR("Invalid type").
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt16:
      return torch::kInt16;
    case transformer_engine::DType::kInt32:
      return torch::kInt32;
    case transformer_engine::DType::kInt64:
      return torch::kInt64;
    case transformer_engine::DType::kFloat32:
      return at::kFloat;
    case transformer_engine::DType::kFloat16:
      return at::kHalf;
    case transformer_engine::DType::kBFloat16:
      return at::kBFloat16;
    case transformer_engine::DType::kByte:
      return at::kByte;
    case transformer_engine::DType::kFloat8E4M3:
      return at::kFloat8_e4m3fn;
    case transformer_engine::DType::kFloat8E5M2:
      return at::kFloat8_e5m2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

// Converts an ATen scalar type to the corresponding Transformer Engine dtype.
// Both at::kBool and torch::kByte map to DType::kByte. For any other
// unlisted type, the numeric type id is printed to stdout and the error is
// reported through NVTE_ERROR("Invalid type").
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kFloat8_e4m3fn:
      return transformer_engine::DType::kFloat8E4M3;
    case at::kFloat8_e5m2:
      return transformer_engine::DType::kFloat8E5M2;
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      return transformer_engine::DType::kByte;
    case torch::kByte:
      return transformer_engine::DType::kByte;
    case torch::kInt16:
      return transformer_engine::DType::kInt16;
    case torch::kInt32:
      return transformer_engine::DType::kInt32;
    case torch::kInt64:
      return transformer_engine::DType::kInt64;
    default:
      // Emit the offending type id before raising so it can be identified.
      std::cout << "Type: " << static_cast<int>(t) << std::endl;
      NVTE_ERROR("Invalid type");
  }
}

// Reinterprets an integer as a Transformer Engine dtype.
// The value is cast directly; no range validation is performed here, so the
// caller is responsible for passing a valid enumerator value.
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  return static_cast<transformer_engine::DType>(DType_value);
}

// Wraps a raw data pointer with the given shape and dtype in a TensorWrapper.
// NOTE(review): presumably non-owning (the buffer must outlive the wrapper);
// the definition is not visible in this header - confirm.
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
319

320
321
322
323
324
325
326
327
328
329
330
331
// Wraps a raw data pointer together with amax / scale / scale_inv pointers.
// `scale_inv_shape` defaults to {1} and `scaling_mode` defaults to
// NVTE_DELAYED_TENSOR_SCALING. Defined outside this header.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

// Same as above, but additionally takes column-wise data and a column-wise
// scale_inv pointer, each with its own shape. Both scale_inv shapes default
// to {1}. Defined outside this header.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape = {1},
    const std::vector<size_t>& columnwise_scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
Przemek Tredak's avatar
Przemek Tredak committed
332
333
334

// Wraps a raw data pointer in a TensorWrapper; the shape is given as an
// NVTEShape rather than a std::vector. Defined outside this header.
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
336
337
338

// Wraps an ATen tensor in a TensorWrapper. Defined outside this header.
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

339
340
341
342
// Converts nested lists of ATen tensors into Transformer Engine tensors.
// Returns a tuple of: the TensorWrapper objects, nested lists of NVTETensor
// handles, a vector of NVTETensor* pointers, and two size_t values.
// NOTE(review): the two size_t values are presumably the number of lists and
// the number of tensors per list; the definition is not visible - confirm.
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
           std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);

343
344
345
346
347
348
349
350
// Builds a TensorWrapper from a Python tensor handle and a Python quantizer
// handle. Defined outside this header.
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);

// Wraps an ATen tensor together with its amax / scale / scale_inv tensors.
// `scaling_mode` defaults to NVTE_DELAYED_TENSOR_SCALING.
// Defined outside this header.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv,
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

// Declaration only; the template definition is not in this header.
// NOTE(review): presumably returns the product of all entries of `shape`.
template <typename T>
T product(const std::vector<T>& shape);
351

352
353
354
// NOTE(review): presumably the product of the dimensions of `shape` in the
// index range starting at `begin`; whether `end` is exclusive is not
// visible from this declaration - confirm against the definition.
size_t product(const NVTEShape& shape, size_t begin, size_t end);

// Returns the dimensions of an NVTEShape as a std::vector<size_t>.
std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape);
Przemek Tredak's avatar
Przemek Tredak committed
355

356
// Allocates an ATen tensor with the given shape and Transformer Engine dtype.
// `init_to_zeros` has no default in this overload and must be passed.
// NOTE(review): presumably zero-fills the buffer when `init_to_zeros` is
// true; the definition is not visible in this header - confirm.
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
                         bool init_to_zeros);
Przemek Tredak's avatar
Przemek Tredak committed
358

359
// Allocates an ATen tensor with the given NVTEShape and Transformer Engine
// dtype. `init_to_zeros` defaults to false in this overload.
// NOTE(review): presumably zero-fills the buffer when `init_to_zeros` is
// true; the definition is not visible in this header - confirm.
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
                         bool init_to_zeros = false);

362
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
363

364
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
365

366
void* getDataPtr(at::Tensor tensor, int offset = 0);
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
// Returns the dimensions of an NVTEShape as a std::vector<size_t>.
std::vector<size_t> convertShape(const NVTEShape& shape);

// NOTE(review): presumably rounds `value` up to the nearest multiple of
// `multiple`; the definition is not visible in this header - confirm.
int roundup(const int value, const int multiple);

}  // namespace transformer_engine::pytorch

// String conversions for shape-like containers, rendered as "[a,b,c]"
// ("[]" when empty).
//
// NOTE(review): the C++ standard only permits adding template
// specializations to namespace std; adding new overloads like these is
// formally undefined behavior. They are kept in place because callers may
// rely on the std::to_string(...) spelling.
namespace std {
// std::vector -> string
template <typename T>
string to_string(const vector<T>& vec) {
  string out = "[";
  bool need_sep = false;
  for (const auto& elem : vec) {
    if (need_sep) {
      out += ",";
    }
    out += to_string(elem);
    need_sep = true;
  }
  out += "]";
  return out;
}

// Torch shape (c10::ArrayRef) -> string
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
  string out = "[";
  bool need_sep = false;
  for (const auto& elem : vec) {
    if (need_sep) {
      out += ",";
    }
    out += to_string(elem);
    need_sep = true;
  }
  out += "]";
  return out;
}

// NVTEShape -> string, using the first `ndim` entries of `data`
inline string to_string(const NVTEShape& s) {
  string out = "[";
  for (size_t i = 0; i < s.ndim; ++i) {
    if (i != 0) {
      out += ",";
    }
    out += to_string(s.data[i]);
  }
  out += "]";
  return out;
}
}  // namespace std

Przemek Tredak's avatar
Przemek Tredak committed
418
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_