/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>

#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>

#include "c10/util/ArrayRef.h"
#include "common/util/logging.h"

namespace transformer_engine::pytorch {
// Holds the delayed-scaling FP8 metadata for a single module
// (e.g. LayerNormLinear). Each tensor is shape (N,), holding all
// scaling data for the module's FP8 tensors; indices into these
// tensors are the FP8FwdTensors/FP8BwdTensors enums below.
class FP8TensorMeta {
 public:
  at::Tensor scale;         // per-tensor FP8 scaling factors
  at::Tensor scale_inv;     // reciprocals of `scale`
  at::Tensor amax_history;  // amax history used to update `scale` (delayed scaling)
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class
// for tensors used in the forward pass.
enum FP8FwdTensors {
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class
// for tensors used in the backward pass.
enum FP8BwdTensors {
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5
};

// Abstract interface over the Python-side quantizer objects: knows how
// to create quantized tensors and attach scaling metadata to them.
class Quantizer {
 public:
  // Scaling scheme implemented by this quantizer (e.g. delayed tensor
  // scaling vs. MXFP8 1D block scaling).
  virtual NVTEScalingMode get_scaling_mode() const = 0;

  // Attaches this quantizer's scaling parameters to `tensor`.
  virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

  // Creates a tensor of `shape`/`dtype`, returning both the TE wrapper
  // and the corresponding Python object. `rowwise_data`, when provided,
  // is presumably reused as the row-wise data buffer — see implementations.
  virtual std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;

  virtual ~Quantizer() = default;

  bool rowwise_usage = true;     // produce row-wise data
  bool columnwise_usage = true;  // produce column-wise data
  bool internal = false;         // TE-internal tensor (not user-facing)
  py::handle quantizer;          // non-owning handle to the Python quantizer

 protected:
  explicit Quantizer(const py::handle& quantizer);
};

// Pass-through "quantizer" used when no quantization is requested:
// attaches no scaling parameters and creates plain tensors.
class NoneQuantizer : public Quantizer {
 public:
  explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {}

  // Reports the default (delayed tensor scaling) mode.
  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  // Nothing to attach for unquantized tensors.
  void set_quantization_params(TensorWrapper* tensor) const override {}

  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

// FP8 quantizer using delayed tensor scaling; carries the scale /
// scale_inv / amax tensors shared with the Python-side quantizer.
class Float8Quantizer : public Quantizer {
 public:
  at::Tensor scale;      // FP8 scaling factor
  at::Tensor scale_inv;  // reciprocal of `scale`
  at::Tensor amax;       // abs-max used to update `scale`
  DType dtype;           // target FP8 dtype

  explicit Float8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

// FP8 quantizer using MXFP8 1D block scaling (scales computed per
// block rather than per tensor, so no tensor-wide scale members).
class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;  // target FP8 dtype

  explicit MXFP8Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

// Builds the matching C++ Quantizer subclass for a Python quantizer object.
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

// Returns the sizes of `t` as a size_t vector.
std::vector<size_t> getTensorShape(at::Tensor t);

// Selects the FP8 dtype implied by the `fp8_recipe` format string.
// `e4m3_if_hybrid` presumably picks E4M3 for the HYBRID format — confirm
// against the implementation.
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string& fp8_recipe);

// Maps a Transformer Engine dtype to the corresponding ATen scalar type.
// Raises (via NVTE_ERROR) for dtypes with no ATen equivalent.
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt32:
      return torch::kInt32;
    case transformer_engine::DType::kInt64:
      return torch::kInt64;
    case transformer_engine::DType::kFloat32:
      return at::kFloat;
    case transformer_engine::DType::kFloat16:
      return at::kHalf;
    case transformer_engine::DType::kBFloat16:
      return at::kBFloat16;
    case transformer_engine::DType::kByte:
      return at::kByte;
    case transformer_engine::DType::kFloat8E4M3:
      return at::kFloat8_e4m3fn;
    case transformer_engine::DType::kFloat8E5M2:
      return at::kFloat8_e5m2;
    default:
      // Include the offending value in the exception instead of printing it.
      NVTE_ERROR("Invalid type: " + std::to_string(static_cast<int>(t)));
  }
}

// Maps an ATen scalar type to the corresponding Transformer Engine dtype.
// Raises (via NVTE_ERROR) for scalar types TE does not support.
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kFloat8_e4m3fn:
      return transformer_engine::DType::kFloat8E4M3;
    case at::kFloat8_e5m2:
      return transformer_engine::DType::kFloat8E5M2;
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
      // TE has no bool dtype; bools are handled as raw bytes.
      return transformer_engine::DType::kByte;
    case torch::kByte:
      return transformer_engine::DType::kByte;
    case torch::kInt32:
      return transformer_engine::DType::kInt32;
    case torch::kInt64:
      return transformer_engine::DType::kInt64;
    default:
      // Fold the offending value into the exception message instead of the
      // previous debugging std::cout, so the information reaches the caller.
      NVTE_ERROR("Invalid type: " + std::to_string(static_cast<int>(t)));
  }
}

// Reinterprets an integer (e.g. coming from Python) as a TE dtype.
// No range check is performed; callers must pass a valid enum value.
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  return static_cast<transformer_engine::DType>(DType_value);
}

// Wraps a raw buffer in a TE TensorWrapper (no ownership taken).
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
                                                              const transformer_engine::DType type);

// Same, additionally attaching FP8 metadata (amax/scale/scale_inv
// pointers) and the scaling mode.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

// Same, with separate row-wise and column-wise data and scale-inv buffers.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape = {1},
    const std::vector<size_t>& columnwise_scale_inv_shape = {1},
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
233
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
234
235
236

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

237
238
239
240
241
242
243
244
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);

transformer_engine::TensorWrapper makeTransformerEngineTensor(
    at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv,
    NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING);

template <typename T>
T product(const std::vector<T>& shape);
245

246
247
248
size_t product(const NVTEShape& shape, size_t begin, size_t end);

std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape);
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
cyanguwa's avatar
cyanguwa committed
251
                         bool init_to_zeros);
Przemek Tredak's avatar
Przemek Tredak committed
252

253
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
Przemek Tredak's avatar
Przemek Tredak committed
254
255
                         bool init_to_zeros = false);

256
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
257

258
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
259

260
void* getDataPtr(at::Tensor tensor, int offset = 0);
// Converts an NVTEShape into a std::vector<size_t>.
std::vector<size_t> convertShape(const NVTEShape& shape);

// Rounds `value` up to the nearest multiple of `multiple`.
int roundup(const int value, const int multiple);

}  // namespace transformer_engine::pytorch

// NOTE(review): adding new overloads to namespace std is formally
// undefined behavior per [namespace.std]; kept as-is because existing
// callers rely on unqualified to_string lookups. Consider relocating
// these into a project namespace.
namespace std {

// Renders a vector as "[a,b,c]" ("[]" when empty).
template <typename T>
string to_string(const vector<T>& vec) {
  string out = "[";
  bool first = true;
  for (const auto& item : vec) {
    if (!first) {
      out += ",";
    }
    out += to_string(item);
    first = false;
  }
  out += "]";
  return out;
}

// Renders a torch shape (c10::ArrayRef) as "[a,b,c]" ("[]" when empty).
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
  string out = "[";
  bool first = true;
  for (const auto& item : vec) {
    if (!first) {
      out += ",";
    }
    out += to_string(item);
    first = false;
  }
  out += "]";
  return out;
}

// Renders an NVTEShape as "[a,b,c]" ("[]" when ndim == 0).
inline string to_string(const NVTEShape& s) {
  string out = "[";
  for (size_t i = 0; i < s.ndim; ++i) {
    if (i != 0) {
      out += ",";
    }
    out += to_string(s.data[i]);
  }
  out += "]";
  return out;
}
}  // namespace std

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_