/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"
8

9
10
#include "c10/util/ArrayRef.h"
#include "pybind.h"
Przemek Tredak's avatar
Przemek Tredak committed
11
#include "transformer_engine/transformer_engine.h"
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
namespace transformer_engine::pytorch {

// Returns the sizes of `t` as a vector of size_t, in the order given by t.sizes().
std::vector<size_t> getTensorShape(at::Tensor t) {
  const auto dims = t.sizes();
  // Each element of sizes() is converted to size_t while the vector is filled.
  return std::vector<size_t>(dims.begin(), dims.end());
}

// Builds the C++ Quantizer corresponding to a Python quantizer object.
//
// A Python None is wrapped in a NoneQuantizer. Otherwise the entries of
// detail::custom_types_converters are tried in order, and the first entry whose
// quantizer type check accepts the object creates the result. If no entry
// accepts it, the failure is reported through NVTE_ERROR.
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) {
  init_extension();

  if (quantizer.is_none()) {
    return std::make_unique<NoneQuantizer>(quantizer);
  }

  auto* const raw_quantizer = quantizer.ptr();
  // Only the quantizer-related members of each converter entry are used here.
  for (auto [unused_is_tensor, is_quantizer, unused_make_tensor, make_quantizer] :
       detail::custom_types_converters) {
    if (is_quantizer(raw_quantizer)) {
      return make_quantizer(quantizer);
    }
  }

  NVTE_ERROR("Unexpected type for quantizer");
}
Przemek Tredak's avatar
Przemek Tredak committed
36
37

// Selects the FP8 data type for a recipe name.
//
// Returns kFloat8E4M3 when `fp8_recipe` is "E4M3", or when it is "HYBRID" and
// `e4m3_if_hybrid` is true. Every other combination returns kFloat8E5M2.
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string& fp8_recipe) {
  const bool use_e4m3 =
      (fp8_recipe == "E4M3") || ((fp8_recipe == "HYBRID") && e4m3_if_hybrid);
  if (use_e4m3) {
    return transformer_engine::DType::kFloat8E4M3;
  }
  return transformer_engine::DType::kFloat8E5M2;
}

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// Wraps a Python tensor object in a TensorWrapper, using `quantizer` for its
// quantization parameters.
//
// The tensor must not be None. If one of detail::custom_types_converters
// recognises the tensor's type, that entry builds the wrapper (after checking
// that `quantizer` is None or of the matching quantizer type). Otherwise the
// object is cast to at::Tensor and wrapped as row-wise data.
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) {
  NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!");
  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);

  for (auto [is_tensor_type, is_quantizer_type, build_tensor, unused_build_quantizer] :
       detail::custom_types_converters) {
    if (!check_type_matches(is_tensor_type, tensor)) {
      continue;
    }
    NVTE_CHECK(quantizer.is_none() || is_quantizer_type(quantizer.ptr()),
               "Unexpected quantization params type.");
    return build_tensor(tensor, my_quantizer.get());
  }

  // No custom converter matched: handle the object as an at::Tensor.
  at::Tensor plain_tensor = tensor.cast<at::Tensor>();

  // #TODO (pgadzinski) - needed in attention for non-contiguous tensors.
  //if (!plain_tensor.is_contiguous()) {
  //  plain_tensor = plain_tensor.contiguous();
  //}
  TensorWrapper wrapped(my_quantizer->get_scaling_mode());
  wrapped.set_rowwise_data(plain_tensor.data_ptr(),
                           GetTransformerEngineDType(plain_tensor.scalar_type()),
                           getTensorShape(plain_tensor));
  my_quantizer->set_quantization_params(&wrapped);
  return wrapped;
}

Przemek Tredak's avatar
Przemek Tredak committed
74
// Builds a TensorWrapper directly from a data pointer, an NVTEShape and a data type.
// All three arguments are forwarded unchanged to the TensorWrapper constructor.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) {
  return transformer_engine::TensorWrapper(data_ptr, shape, type);
}

// Builds a TensorWrapper directly from a data pointer, a shape vector and a data type.
// All three arguments are forwarded unchanged to the TensorWrapper constructor.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type) {
  return transformer_engine::TensorWrapper(data_ptr, shape, type);
}

// Wraps an at::Tensor in a TensorWrapper using the tensor's data pointer, its
// sizes, and the Transformer Engine data type matching its scalar type.
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
  const transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
  // Use the shared shape helper rather than repeating the size-copy loop here.
  const std::vector<size_t> shape = getTensorShape(tensor);
  return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
}

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Builds a TensorWrapper with row-wise data plus amax, scale and row-wise
// scale-inverse buffers.
//
// amax and scale are registered as single-element float32 buffers. The
// scale-inverse buffer uses kFloat8E8M0 when `scaling_mode` is
// NVTE_MXFP8_1D_SCALING and kFloat32 for every other mode.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape,
    NVTEScalingMode scaling_mode) {
  TensorWrapper wrapper(scaling_mode);
  wrapper.set_rowwise_data(data_ptr, type, shape);

  const std::vector<size_t> single_element{1};
  wrapper.set_amax(amax_ptr, DType::kFloat32, single_element);
  wrapper.set_scale(scale_ptr, DType::kFloat32, single_element);

  DType inverse_scale_type = DType::kFloat32;
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    inverse_scale_type = DType::kFloat8E8M0;
  }
  wrapper.set_rowwise_scale_inv(scale_inv_ptr, inverse_scale_type, scale_inv_shape);
  return wrapper;
}

// Builds a TensorWrapper carrying both row-wise and column-wise data, together
// with amax, scale, and the row-wise and column-wise scale-inverse buffers.
//
// amax and scale are registered as single-element float32 buffers. Both
// scale-inverse buffers use kFloat8E8M0 when `scaling_mode` is
// NVTE_MXFP8_1D_SCALING and kFloat32 for every other mode.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
    const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
    void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
    const std::vector<size_t>& scale_inv_shape,
    const std::vector<size_t>& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) {
  TensorWrapper ret(scaling_mode);
  ret.set_rowwise_data(data_ptr, type, shape);
  ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
  const std::vector<size_t> meta_shape{1};
  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
  const auto scale_inv_dtype =
      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
  ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
                               columnwise_scale_inv_shape);
  return ret;
}

128
// Wraps an at::Tensor and its amax / scale / scale_inv tensors in a TensorWrapper.
//
// amax, scale and scale_inv must all have scalar type at::kFloat; this is
// enforced with NVTE_CHECK. The data type and shape are taken from `tensor`,
// and the scale-inverse shape from `scale_inv`.
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
                                                              const at::Tensor scale,
                                                              at::Tensor scale_inv,
                                                              NVTEScalingMode scaling_mode) {
  transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());

  auto tensor_shape = getTensorShape(tensor);
  auto scale_inv_shape = getTensorShape(scale_inv);

  NVTE_CHECK(amax.scalar_type() == at::kFloat);
  NVTE_CHECK(scale.scalar_type() == at::kFloat);
  NVTE_CHECK(scale_inv.scalar_type() == at::kFloat);

  return makeTransformerEngineTensor(tensor.data_ptr(), tensor_shape, dtype, amax.data_ptr(),
                                     scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape,
                                     scaling_mode);
}

146
147
148
// Returns the product of all entries of `shape`.
// An empty vector yields 1 (the multiplicative identity).
template <typename T>
T product(const std::vector<T>& shape) {
  T ret = 1;
  for (const T dim : shape) {
    ret *= dim;
  }
  return ret;
}

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
// Explicit instantiation definitions of the product<T> template for the
// size_t and int64_t element types.
template size_t product<size_t>(const std::vector<size_t>& shape);
template int64_t product<int64_t>(const std::vector<int64_t>& shape);

// Returns the product of the entries of `shape` at indices [begin, end).
// NVTE_CHECK requires begin <= end and end <= shape.ndim. An empty range yields 1.
size_t product(const NVTEShape& shape, size_t begin, size_t end) {
  NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end,
             " in a shape with ", shape.ndim, " entries");
  size_t total = 1;
  size_t idx = begin;
  while (idx < end) {
    total *= shape.data[idx];
    ++idx;
  }
  return total;
}

// Copies the first `ndim` entries of an NVTEShape into a std::vector<size_t>.
std::vector<size_t> nvte_shape_to_vector(const NVTEShape& nvte_shape) {
  const auto* const first = nvte_shape.data;
  return std::vector<size_t>(first, first + nvte_shape.ndim);
}

176
// Allocates a CUDA tensor with the given shape and the ATen dtype that
// corresponds to `type`.
//
// When `init_to_zeros` is true the tensor is created with at::zeros,
// otherwise with at::empty.
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
                         bool init_to_zeros) {
  // ATen expects int64_t sizes, so convert the size_t shape first.
  std::vector<int64_t> shape_int64(shape.begin(), shape.end());
  c10::IntArrayRef ar_shape(shape_int64);
  if (init_to_zeros) {
    return at::zeros(ar_shape, at::CUDA(GetATenDType(type)));
  }
  return at::empty(ar_shape, at::CUDA(GetATenDType(type)));
}

187
// Allocates a CUDA tensor for a 1-D or 2-D NVTEShape, with the ATen dtype that
// corresponds to `type`.
//
// Only shapes with ndim == 1 or ndim == 2 are supported; any other rank fails
// the NVTE_CHECK below. When `init_to_zeros` is true the tensor is created
// with at::zeros, otherwise with at::empty.
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
                         bool init_to_zeros) {
  const auto ndim = shape.ndim;
  // Validate the rank first so that every path after this point returns a tensor.
  NVTE_CHECK(ndim == 1 || ndim == 2, "Should never reach here! func: allocateSpace");

  // ATen expects int64_t sizes, so convert the shape entries first.
  std::vector<int64_t> shape_int64(shape.data, shape.data + ndim);
  c10::IntArrayRef ar_shape(shape_int64);
  if (init_to_zeros) {
    return at::zeros(ar_shape, at::CUDA(GetATenDType(type)));
  }
  return at::empty(ar_shape, at::CUDA(GetATenDType(type)));
}

204
205
206
// Allocates an uninitialised (at::empty) CUDA tensor of shape {M, N} with the
// ATen dtype that corresponds to `dtype`.
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) {
  return at::empty({static_cast<int64_t>(M), static_cast<int64_t>(N)},
                   at::CUDA(GetATenDType(dtype)));
}

209
210
// Allocates an uninitialised (at::empty) CUDA tensor of shape {M} with the
// ATen dtype that corresponds to `dtype`.
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype) {
  return at::empty({static_cast<int64_t>(M)}, at::CUDA(GetATenDType(dtype)));
}
212

213
// Returns the data pointer of `tensor`, advanced by `offset` elements.
//
// Returns nullptr when the tensor has no elements. The offset is counted in
// elements and scaled by tensor.element_size() to a byte offset; it is only
// applied when the data pointer is non-null and the offset is non-zero.
void* getDataPtr(at::Tensor tensor, int offset) {
  if (tensor.numel() <= 0) {
    return nullptr;
  }
  void* dptr = tensor.data_ptr();
  if (dptr == nullptr || offset == 0) {
    return dptr;
  }
  // Byte-wise pointer arithmetic needs a char*; static_cast is enough for void* <-> char*.
  char* char_ptr = static_cast<char*>(dptr);
  char_ptr += offset * tensor.element_size();
  return char_ptr;
}
225
226
227
228
229
230
231
232
233
234
235

// Copies the first `ndim` entries of an NVTEShape into a std::vector<size_t>.
std::vector<size_t> convertShape(const NVTEShape& shape) {
  std::vector<size_t> dims;
  dims.reserve(shape.ndim);
  for (size_t i = 0; i < shape.ndim; ++i) {
    dims.push_back(shape.data[i]);
  }
  return dims;
}

// Rounds `value` up to a multiple of `multiple` using truncating integer division.
// `multiple` must be positive (checked with assert).
int roundup(const int value, const int multiple) {
  assert(multiple > 0);
  const int padded = value + multiple - 1;
  const int chunks = padded / multiple;
  return chunks * multiple;
}

}  // namespace transformer_engine::pytorch