/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
cyanguwa's avatar
cyanguwa committed
11
#include <ATen/Dispatch.h>
Tim Moon's avatar
Tim Moon committed
12
#include <ATen/cuda/CUDAContext.h>
cyanguwa's avatar
cyanguwa committed
13
#include <ATen/cuda/CUDAGeneratorImpl.h>
Tim Moon's avatar
Tim Moon committed
14
15
16
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
17
18
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
yuguo's avatar
yuguo committed
19
20
#include <cuda_runtime.h>
#ifndef USE_ROCM
Tim Moon's avatar
Tim Moon committed
21
#include <cublasLt.h>
Przemek Tredak's avatar
Przemek Tredak committed
22
#include <cuda.h>
23
#include <cudnn.h>
yuguo's avatar
yuguo committed
24
25
26
27
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
Tim Moon's avatar
Tim Moon committed
28
29
30
31
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
32
#include <transformer_engine/cast_transpose_noop.h>
33
#include <transformer_engine/comm_gemm_overlap.h>
Tim Moon's avatar
Tim Moon committed
34
#include <transformer_engine/fused_attn.h>
35
#include <transformer_engine/fused_rope.h>
Tim Moon's avatar
Tim Moon committed
36
#include <transformer_engine/gemm.h>
37
#include <transformer_engine/multi_stream.h>
38
#include <transformer_engine/multi_tensor.h>
39
#include <transformer_engine/normalization.h>
40
#include <transformer_engine/padding.h>
41
#include <transformer_engine/permutation.h>
42
#include <transformer_engine/recipe.h>
Tim Moon's avatar
Tim Moon committed
43
#include <transformer_engine/softmax.h>
44
#include <transformer_engine/swizzle.h>
Tim Moon's avatar
Tim Moon committed
45
46
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
47
48

#include <ATen/cuda/CUDAGraphsUtils.cuh>
49
#include <cassert>
50
51
#include <cstring>
#include <iostream>
52
53
#include <string>
#include <sstream>
54
#include <memory>
55
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
56
57
#include <vector>

58
#include "c10/util/ArrayRef.h"
59
#include "common/util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
60

61
namespace transformer_engine::pytorch {
Przemek Tredak's avatar
Przemek Tredak committed
62

63
64
65
// C++ type of a torch.distributed process group. This mirrors the Python-side
// alias `dist_group_type = torch.distributed.ProcessGroup`.
using dist_group_type = c10d::ProcessGroup;
// Returns the block length used for blockwise FP8 scaling.
//
// The value comes from the NVTE_BLOCKWISE_FP8_BLOCK_LEN environment variable.
// If the variable is unset or empty, the default of 128 is used. Otherwise it
// must be a positive integer, optionally surrounded by whitespace. Anything else
// (e.g. "abc", "128abc", "0", "-64") raises an error instead of being silently
// truncated or accepted.
inline int blockwise_fp8_block_len() {
  const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
  if (env == nullptr || env[0] == '\0') {
    return 128;
  }
  int value = 0;
  std::istringstream iss(env);
  // Extracting a char succeeds only if a non-whitespace character follows the
  // number, i.e. the string was only partially parsed.
  char trailing = '\0';
  const bool parsed = static_cast<bool>(iss >> value) && !(iss >> trailing);
  NVTE_CHECK(parsed, "Invalid environment variable value for NVTE_BLOCKWISE_FP8_BLOCK_LEN: \"",
             env, "\"");
  NVTE_CHECK(value > 0, "NVTE_BLOCKWISE_FP8_BLOCK_LEN must be a positive integer, got ", value);
  return value;
}
// FP8 scaling state for one FP8 block, e.g. a LayerNormLinear module.
// Each tensor holds the scaling data for every scaled tensor of that block.
// The FP8FwdTensors / FP8BwdTensors enums name the individual slots.
// NOTE(review): the original note gives the shape as (N, ); `amax_history`
// presumably also carries a history dimension — confirm.
class FP8TensorMeta {
 public:
  // Scaling factors.
  at::Tensor scale;
  // Inverse scaling factors.
  at::Tensor scale_inv;
  // Recorded amax values.
  at::Tensor amax_history;
};

// Forward-pass slot indices into the per-tensor scaling tensors of
// `FP8TensorMeta`. There are three GEMMs, each with an input, a weight and an
// output slot, numbered consecutively.
enum FP8FwdTensors {
  // First GEMM.
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  // Second GEMM.
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  // Third GEMM.
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8,
};

// Backward-pass slot indices into the per-tensor scaling tensors of
// `FP8TensorMeta`. Each of the three GEMMs has a grad-output and a grad-input
// slot, numbered consecutively.
enum FP8BwdTensors {
  // First GEMM.
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  // Second GEMM.
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  // Third GEMM.
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5,
};
// Abstract base for the C++ side of a Python quantizer object. Each subclass
// reports a scaling mode, applies its quantization parameters to a
// TensorWrapper, and creates new tensors.
class Quantizer {
 public:
  // NVTE scaling mode implemented by this quantizer.
  virtual NVTEScalingMode get_scaling_mode() const = 0;

  // Applies this quantizer's parameters to `tensor`.
  // NOTE(review): which fields are written depends on the subclass implementation (not visible here).
  virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

  // Creates a tensor with the given shape and dtype. Returns both the C++
  // TensorWrapper and the matching Python object. `rowwise_data`, when provided,
  // presumably supplies existing row-wise storage — TODO confirm against the implementations.
  virtual std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;

  // Virtual so that subclasses are destroyed correctly through a base pointer,
  // e.g. the std::unique_ptr<Quantizer> returned by convert_quantizer.
  virtual ~Quantizer() = default;

  // Presumably whether row-wise / column-wise quantized data is requested; both default to true.
  bool rowwise_usage = true;
  bool columnwise_usage = true;
  // Defaults to false. NOTE(review): meaning not visible in this header — confirm in the implementation.
  bool internal = false;
  // Handle to the originating Python quantizer. A pybind11 py::handle does not
  // own a reference. NOTE(review): confirm the Python object outlives this instance.
  py::handle quantizer;

 protected:
  // Only subclasses construct a Quantizer, from the Python quantizer handle.
  explicit Quantizer(const py::handle& quantizer);
};

// Pass-through quantizer, presumably used when no quantization is requested.
// It reports NVTE_DELAYED_TENSOR_SCALING and applies no quantization parameters.
class NoneQuantizer : public Quantizer {
 public:
  explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {}

  NVTEScalingMode get_scaling_mode() const override {
    return NVTE_DELAYED_TENSOR_SCALING;
  }

  // Intentionally a no-op: there are no quantization parameters to apply.
  void set_quantization_params(TensorWrapper* /*tensor*/) const override {}

  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

class Float8Quantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;

  explicit Float8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

class Float8CurrentScalingQuantizer : public Quantizer {
 public:
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax;
  DType dtype;
  bool with_amax_reduction;
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0;

  explicit Float8CurrentScalingQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

181
182
183
184
185
186
187
188
189
190
191
192
193
194
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

class Float8BlockQuantizer : public Quantizer {
 public:
  // Which float8 type is used for q data.
  DType dtype;
  // Options about how to quantize the tensor
  // Quantization scales are rounded down to powers of 2.
  bool force_pow_2_scales = false;
  // Amax within quantization tile has a floor of epsilon.
  float amax_epsilon = 0.0;
195
196
  // Whether quantized tensor will be used in an all-gather
  bool all_gather_usage = false;
197
198

 private:
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
  int block_scaling_dim = 2;

 public:
  // Initializes from a python handle to a Float8BlockQuantizer
  explicit Float8BlockQuantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override {
    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
  }

  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
  void set_quantization_params(TensorWrapper* tensor) const override;

  // Create a python Float8BlockQuantized tensor and C++ wrapper
  // for the tensor. Should set quantized data, scales for rowwise
  // and optionally columnwise usage.
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

// MXFP8 quantizer. It always reports the NVTE_MXFP8_1D_SCALING mode.
class MXFP8Quantizer : public Quantizer {
 public:
  // Element type of the quantized data.
  DType dtype;

  explicit MXFP8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override {
    return NVTE_MXFP8_1D_SCALING;
  }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

std::vector<size_t> getTensorShape(at::Tensor t);
Przemek Tredak's avatar
Przemek Tredak committed
238
239

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
240
                                                      const std::string& fp8_recipe);
Przemek Tredak's avatar
Przemek Tredak committed
241

242
// Number of bits per element for a Transformer Engine dtype. The sub-byte
// kFloat4E2M1 reports 4. Any dtype not listed raises via NVTE_ERROR.
inline size_t typeToNumBits(transformer_engine::DType t) {
  using transformer_engine::DType;
  switch (t) {
    case DType::kFloat4E2M1:
      return 4;
    case DType::kByte:
    case DType::kInt8:
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return 8;
    case DType::kInt16:
    case DType::kFloat16:
    case DType::kBFloat16:
      return 16;
    case DType::kInt32:
    case DType::kFloat32:
      return 32;
    case DType::kInt64:
      return 64;
    default:
      NVTE_ERROR("Invalid type");
  }
}
// Maps a Transformer Engine dtype to the matching ATen scalar type.
// Dtypes without a mapping here (for example kFloat4E2M1) raise via NVTE_ERROR.
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  using transformer_engine::DType;
  switch (t) {
    // Floating-point types.
    case DType::kFloat32:
      return at::kFloat;
    case DType::kFloat16:
      return at::kHalf;
    case DType::kBFloat16:
      return at::kBFloat16;
    case DType::kFloat8E4M3:
      return at::kFloat8_e4m3fn;
    case DType::kFloat8E5M2:
      return at::kFloat8_e5m2;
    // Byte and integer types.
    case DType::kByte:
      return at::kByte;
    case DType::kInt8:
      return torch::kInt8;
    case DType::kInt16:
      return torch::kInt16;
    case DType::kInt32:
      return torch::kInt32;
    case DType::kInt64:
      return torch::kInt64;
    default:
      NVTE_ERROR("Invalid type");
  }
}

// Maps an ATen scalar type to the matching Transformer Engine dtype.
// at::kBool maps to kByte. Unsupported scalar types raise via NVTE_ERROR, and
// the error message names the offending type.
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kFloat8_e4m3fn:
      return transformer_engine::DType::kFloat8E4M3;
    case at::kFloat8_e5m2:
      return transformer_engine::DType::kFloat8E5M2;
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      return transformer_engine::DType::kByte;
    case torch::kByte:
      return transformer_engine::DType::kByte;
    case torch::kInt16:
      return transformer_engine::DType::kInt16;
    case torch::kInt32:
      return transformer_engine::DType::kInt32;
    case torch::kInt64:
      return transformer_engine::DType::kInt64;
    case torch::kInt8:
      return transformer_engine::DType::kInt8;
    default:
      // Put the unsupported type in the exception itself so the information
      // reaches the caller. Printing it to stdout (and flushing) would not.
      NVTE_ERROR("Invalid type: ", c10::toString(t), " (", static_cast<int>(t), ")");
  }
}

// Reinterprets a raw integer code (presumably a dtype value coming from Python)
// as a Transformer Engine dtype. The value is not range-checked.
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  const auto converted = static_cast<transformer_engine::DType>(DType_value);
  return converted;
}

// Wraps a raw data buffer with the given shape and dtype in a TensorWrapper.
// This overload takes no scaling pointers.
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
                                                              const transformer_engine::DType type);

// Like the overload above, but also attaches amax, scale and scale-inverse
// pointers. The scale-inverse shape defaults to {1} and the scaling mode
// defaults to delayed tensor scaling.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

// Variant carrying both row-wise and column-wise data buffers, each with its own
// shape and scale-inverse buffer and shape.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape = {1},
    const std::vector<size_t>& columnwise_scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
345
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
346
347
348

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

349
350
351
352
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
           std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);

353
354
355
356
357
358
359
360
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv,
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

template <typename T>
T product(const std::vector<T>& shape);
361

362
363
364
size_t product(const NVTEShape& shape, size_t begin, size_t end);

std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape);
Przemek Tredak's avatar
Przemek Tredak committed
365

366
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
cyanguwa's avatar
cyanguwa committed
367
                         bool init_to_zeros);
Przemek Tredak's avatar
Przemek Tredak committed
368

369
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
Przemek Tredak's avatar
Przemek Tredak committed
370
371
                         bool init_to_zeros = false);

372
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
373

374
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
375

376
void* getDataPtr(at::Tensor tensor, int offset = 0);
377

378
379
380
381
std::vector<size_t> convertShape(const NVTEShape& shape);

int roundup(const int value, const int multiple);

382
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
}  // namespace transformer_engine::pytorch

namespace std {
// NOTE(review): [namespace.std] makes it undefined behavior to add new overloads
// (as opposed to permitted specializations) to namespace std. Moving these into
// a project namespace would require updating every caller that writes
// `std::to_string(...)`.

// Formats a vector as "[a,b,c]", or "[]" when empty. Each element is
// stringified with to_string.
template <typename T>
string to_string(const vector<T>& vec) {
  string out(1, '[');
  bool first = true;
  for (const auto& elem : vec) {
    if (!first) {
      out += ',';
    }
    out += to_string(elem);
    first = false;
  }
  out += ']';
  return out;
}

// Formats a torch shape (c10::ArrayRef) as "[a,b,c]", or "[]" when empty.
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
  string out(1, '[');
  for (size_t i = 0; i < vec.size(); ++i) {
    if (i != 0) {
      out += ',';
    }
    out += to_string(vec[i]);
  }
  out += ']';
  return out;
}

// Formats the first `s.ndim` dimensions of an NVTEShape as "[a,b,c]", or "[]"
// when ndim is 0.
inline string to_string(const NVTEShape& s) {
  string out(1, '[');
  for (size_t i = 0; i < s.ndim; ++i) {
    if (i > 0) {
      out += ',';
    }
    out += to_string(s.data[i]);
  }
  out += ']';
  return out;
}
}  // namespace std

Przemek Tredak's avatar
Przemek Tredak committed
429
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_