quantizer.cpp 20.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <pybind.h>

#include "common.h"
#include "pybind.h"
#include "torch/torch.h"

namespace transformer_engine::pytorch {

// Block length used by MXFP8Quantizer::create_tensor: both the last dimension and the product
// of the leading dimensions must be divisible by it, and it divides those extents when sizing
// the scale-inverse buffers.
constexpr size_t MXFP8_BLOCK_SIZE = 32;

// Builds the base quantizer state from a Python quantizer object.
//
// When `quantizer` is Python None, both usages are enabled, `internal` is false, and the stored
// Python handle is left untouched. Otherwise the flags are read from the object's
// "rowwise_usage", "columnwise_usage" and "internal" attributes and the handle is stored.
Quantizer::Quantizer(const py::handle& quantizer) {
  if (quantizer.is_none()) {
    this->rowwise_usage = true;
    this->columnwise_usage = true;
    this->internal = false;
    return;
  }

  // Small helper: fetch a boolean attribute from the Python object.
  const auto read_flag = [&quantizer](const char* attr_name) {
    return quantizer.attr(attr_name).cast<bool>();
  };

  this->rowwise_usage = read_flag("rowwise_usage");
  this->columnwise_usage = read_flag("columnwise_usage");
  this->internal = read_flag("internal");
  this->quantizer = quantizer;
}

// Reads the "scale", "amax" and "dtype" attributes of the Python quantizer (in that order) and
// stores them on this object after the base Quantizer state has been initialized.
Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
  at::Tensor scale_tensor = quantizer.attr("scale").cast<at::Tensor>();
  at::Tensor amax_tensor = quantizer.attr("amax").cast<at::Tensor>();
  const auto fp8_dtype = quantizer.attr("dtype").cast<DType>();

  this->amax = std::move(amax_tensor);
  this->scale = std::move(scale_tensor);
  this->dtype = fp8_dtype;
}

std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
  at::TensorOptions opts;
  opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA);
  std::vector<int64_t> torch_shape;
  for (auto s : shape) {
    torch_shape.emplace_back(static_cast<int64_t>(s));
  }
  at::Tensor ret;
  if (rowwise_data.has_value()) {
    ret = std::move(*rowwise_data);
  } else {
    ret = at::empty(torch_shape, opts);
  }

  TensorWrapper tensor;
  tensor.set_rowwise_data(ret.data_ptr(), dtype, shape);
  return {std::move(tensor), py::cast(ret)};
}

// Attaches this quantizer's parameters to `tensor`.
//
// The tensor's scale and amax are pointed at the quantizer's `scale` and `amax` buffers, and
// its row-wise and column-wise data are re-registered with the quantizer's FP8 dtype while
// keeping the data pointers and shapes already stored in the wrapper.
void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
  tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()),
                    getTensorShape(scale));
  tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                   getTensorShape(amax));

  // Read back the current buffers and override only their dtype.
  auto rowwise_data = tensor->get_rowwise_data();
  rowwise_data.dtype = static_cast<NVTEDType>(dtype);

  auto columnwise_data = tensor->get_columnwise_data();
  columnwise_data.dtype = static_cast<NVTEDType>(dtype);

  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
                           rowwise_data.shape);
  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
                              columnwise_data.shape);
}

std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
  using namespace pybind11::literals;
  std::vector<int64_t> rowwise_torch_shape;
  std::vector<int64_t> columnwise_torch_shape;

  if (!shape.empty()) {
    columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
  }
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i < shape.size() - 1) {
      columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
    }
    rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
  }
  at::TensorOptions opts;
  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
  at::Tensor data;
  if (rowwise_usage) {
    if (rowwise_data.has_value()) {
      data = std::move(*rowwise_data);
    } else {
      data = at::empty(rowwise_torch_shape, opts);
    }
  }
  const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
  at::Tensor columnwise_data;
105
  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
106
107
108
109
110
  if (create_transpose) {
    columnwise_data = at::empty(columnwise_torch_shape, opts);
  }
  const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
  opts = opts.dtype(torch::kFloat32);
111
  // TODO: Replace with an empty tensor.
112
  at::Tensor scale_inv = at::reciprocal(scale);
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
  py::object ret;
  if (internal) {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
    ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
                            "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
                            "quantizer"_a = this->quantizer);
  } else {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
    ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
                            "data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
                            "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
                            "quantizer"_a = this->quantizer);
  }
  TensorWrapper tensor(this->get_scaling_mode());
  if (rowwise_usage) {
    tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  if (create_transpose) {
    std::vector<size_t> transposed_shape;
    for (auto s : columnwise_torch_shape) {
      transposed_shape.emplace_back(static_cast<size_t>(s));
    }
    tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
    tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  this->set_quantization_params(&tensor);
  return {std::move(tensor), std::move(ret)};
}

// Initializes the current-scaling quantizer from its Python counterpart.
//
// Reads the "scale", "amax" and "dtype" attributes, the optional amax reduction group, and the
// current-scaling specific "force_pow_2_scales" and "amax_epsilon" attributes.
// Fails an NVTE_CHECK when amax reduction is requested but
// `_canonicalized_amax_reduction_group()` returns None.
Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer)
    : Quantizer(quantizer) {
  const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
  const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
  const DType type = quantizer.attr("dtype").cast<DType>();
  this->amax = amax;
  this->scale = scale;
  this->dtype = type;

  // Get amax reduction group if needed
  const bool use_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
  c10::intrusive_ptr<dist_group_type> reduction_group;
  if (use_amax_reduction) {
    auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
    NVTE_CHECK(!group.is_none(),
               "Float8CurrentScalingQuantizer could not canonicalize amax reduction group");
    reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
  }
  this->with_amax_reduction = use_amax_reduction;
  this->amax_reduction_group = reduction_group;

  // fp8 current scaling specific quantization params
  this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
  this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
}

// Attaches this quantizer's parameters to `tensor`.
//
// The tensor's scale and amax are pointed at the quantizer's `scale` and `amax` buffers, and
// its row-wise and column-wise data are re-registered with the quantizer's FP8 dtype while
// keeping the data pointers and shapes already stored in the wrapper.
void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const {
  // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them)
  tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()),
                    getTensorShape(scale));
  tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                   getTensorShape(amax));

  // quantize output and its transpose: override only the dtype of the registered buffers
  auto rowwise_data = tensor->get_rowwise_data();
  rowwise_data.dtype = static_cast<NVTEDType>(dtype);

  auto columnwise_data = tensor->get_columnwise_data();
  columnwise_data.dtype = static_cast<NVTEDType>(dtype);

  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
                           rowwise_data.shape);
  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
                              columnwise_data.shape);
}

std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
  using namespace pybind11::literals;
  std::vector<int64_t> rowwise_torch_shape;
  std::vector<int64_t> columnwise_torch_shape;
  std::vector<int64_t> scale_inv_torch_shape = {1};  // Shape of 1 element for scale_inv

  if (!shape.empty()) {
    columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
  }
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i < shape.size() - 1) {
      columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
    }
    rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
  }
  at::TensorOptions opts;
  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
  at::Tensor data;
  if (rowwise_usage) {
    if (rowwise_data.has_value()) {
      data = std::move(*rowwise_data);
    } else {
      data = at::empty(rowwise_torch_shape, opts);
    }
  }
  const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
  at::Tensor columnwise_data;
217
  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
218
219
220
221
222
  if (create_transpose) {
    columnwise_data = at::empty(columnwise_torch_shape, opts);
  }
  const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();

223
224
  // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set.
  at::Tensor scale_inv = at::reciprocal(scale);
225

226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
  py::object ret;
  if (internal) {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
    ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
                            "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
                            "quantizer"_a = this->quantizer);
  } else {
    py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
    ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
                            "data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
                            "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
                            "quantizer"_a = this->quantizer);
  }
  TensorWrapper tensor(this->get_scaling_mode());
  if (rowwise_usage) {
    tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  if (create_transpose) {
    std::vector<size_t> transposed_shape;
    for (auto s : columnwise_torch_shape) {
      transposed_shape.emplace_back(static_cast<size_t>(s));
    }
    tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
    tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  this->set_quantization_params(&tensor);
253
254
255
256
257
258
259

  return {std::move(tensor), std::move(ret)};
}

// Initializes the block-scaling quantizer from its Python counterpart by reading the "dtype",
// "block_scaling_dim", "force_pow_2_scales" and "amax_epsilon" attributes.
// Fails an NVTE_CHECK unless block_scaling_dim is 1 or 2.
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
  this->dtype = quantizer.attr("dtype").cast<DType>();
  this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
  this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
  this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
  NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
             "Unsupported block scaling dim.");
}

// Re-registers the row-wise and column-wise buffers of `tensor` with this quantizer's dtype,
// leaving the stored data pointers and shapes unchanged.
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
  // The configured dtype may differ from what the wrapper currently records
  // (e.g. a switch between E5M2 and E4M3).
  const auto configured_type = static_cast<NVTEDType>(dtype);

  auto row_buffer = tensor->get_rowwise_data();
  auto col_buffer = tensor->get_columnwise_data();
  row_buffer.dtype = configured_type;
  col_buffer.dtype = configured_type;

  tensor->set_rowwise_data(row_buffer.data_ptr, static_cast<DType>(row_buffer.dtype),
                           row_buffer.shape);
  tensor->set_columnwise_data(col_buffer.data_ptr, static_cast<DType>(col_buffer.dtype),
                              col_buffer.shape);
}

std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
  using namespace pybind11::literals;
  std::vector<int64_t> torch_shape;
  size_t numel = 1;
  for (auto s : shape) {
    torch_shape.emplace_back(static_cast<int64_t>(s));
    numel *= s;
  }

  TensorWrapper tensor(this->get_scaling_mode());
  at::TensorOptions opts;
  at::TensorOptions scale_opts;
  at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
  scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);

  size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
  size_t m_dim = numel / k_dim;
  constexpr size_t kBlockLen = 128;

  if (rowwise_usage) {
    if (rowwise_data.has_value()) {
      data_rowwise = std::move(*rowwise_data);
    } else {
      data_rowwise = at::empty(torch_shape, opts);
    }
    size_t sinv0 = 0;
    size_t sinv1 = 0;
    if (block_scaling_dim == 2) {
      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
    } else if (block_scaling_dim == 1) {
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup(m_dim, 4);
    } else {
      NVTE_CHECK(false,
                 "Unsupported block_scaling_dim in create_tensor rowwise."
                 "Expected 1 or 2. Got ",
                 block_scaling_dim);
    }
Xin Yao's avatar
Xin Yao committed
322
323
    scale_inv_rowwise =
        at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
    tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32,
                                 std::vector<size_t>{sinv0, sinv1});
  }

  if (columnwise_usage) {
    std::vector<int64_t> torch_columnwise_shape;
    std::vector<size_t> columnwise_shape;
    NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
               columnwise_shape, " torch shape: ", torch_columnwise_shape);
    if (torch_shape.size() > 0) {
      torch_columnwise_shape.reserve(torch_shape.size());
      columnwise_shape.reserve(shape.size());
      torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
      columnwise_shape.push_back(shape[shape.size() - 1]);
      for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
        torch_columnwise_shape.push_back(torch_shape[i]);
        columnwise_shape.push_back(shape[i]);
      }
    }
    size_t sinv0 = 0;
    size_t sinv1 = 0;
    if (block_scaling_dim == 2) {
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
    } else if (block_scaling_dim == 1) {
      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = roundup(k_dim, 4);
    } else {
      NVTE_CHECK(false,
                 "Unsupported block_scaling_dim in create_tensor columnwise."
                 "Expected 1 or 2. Got ",
                 block_scaling_dim);
    }
    data_colwise = at::empty(torch_columnwise_shape, opts);
Xin Yao's avatar
Xin Yao committed
359
360
    scale_inv_colwise =
        at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386

    tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape);
    tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32,
                                    std::vector<size_t>{sinv0, sinv1});
  }
  this->set_quantization_params(&tensor);

  py::object ret;
  if (internal) {
    py::handle Float8BlockwiseQTensorClass(
        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
    ret = Float8BlockwiseQTensorClass(
        "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
        "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
        "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
        "is_2D_scaled"_a = (block_scaling_dim == 2));
  } else {
    py::handle Float8BlockwiseQTensorClass(
        reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
    ret = Float8BlockwiseQTensorClass(
        "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
        "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
        "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
  }

387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
  return {std::move(tensor), std::move(ret)};
}

// Initializes the base Quantizer state, then stores the Python quantizer's "dtype" attribute.
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
  const py::object dtype_attr = quantizer.attr("dtype");
  this->dtype = dtype_attr.cast<DType>();
}

// Re-registers the row-wise and column-wise buffers of `tensor` with this quantizer's dtype,
// leaving the stored data pointers and shapes unchanged.
void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
  const NVTEDType target_type = static_cast<NVTEDType>(dtype);

  // Snapshot both buffers first, then write them back with the dtype overridden.
  auto rows = tensor->get_rowwise_data();
  rows.dtype = target_type;
  auto cols = tensor->get_columnwise_data();
  cols.dtype = target_type;

  tensor->set_rowwise_data(rows.data_ptr, static_cast<DType>(rows.dtype), rows.shape);
  tensor->set_columnwise_data(cols.data_ptr, static_cast<DType>(cols.dtype), cols.shape);
}

std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
    const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
  using namespace pybind11::literals;
  std::vector<int64_t> torch_shape;
  size_t numel = 1;
  for (auto s : shape) {
    torch_shape.emplace_back(static_cast<int64_t>(s));
    numel *= s;
  }

  TensorWrapper tensor(NVTE_MXFP8_1D_SCALING);
  at::TensorOptions opts;
  at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
      columnwise_scale_inv;  // TODO(pgadzinski) - change
  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
  auto last_dim = static_cast<size_t>(torch_shape.back());

  NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
             " (got shape=", torch_shape, ")");

  at::Tensor data;
  if (rowwise_usage) {
    if (rowwise_data.has_value()) {
      data = std::move(*rowwise_data);
    } else {
      data = at::empty(torch_shape, opts);
    }
    auto sinv0 = roundup(numel / last_dim, 128);
    auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
    rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
    tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
439
440
441
    tensor.set_rowwise_scale_inv(
        rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
        std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
442
443
444
445
446
447
448
449
450
  }

  if (columnwise_usage) {
    auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
    auto sinv1 = roundup(last_dim, 128);
    columnwise_data = at::empty(torch_shape, opts);
    columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);

    tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
451
452
453
    tensor.set_columnwise_scale_inv(
        columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
        std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
  }
  this->set_quantization_params(&tensor);

  py::object ret;
  if (internal) {
    py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
    ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
                           "rowwise_scale_inv"_a = rowwise_scale_inv,
                           "columnwise_scale_inv"_a = columnwise_scale_inv,
                           "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
  } else {
    py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
    ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype),
                           "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
                           "rowwise_scale_inv"_a = rowwise_scale_inv,
                           "columnwise_scale_inv"_a = columnwise_scale_inv,
                           "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
  }

  return {std::move(tensor), std::move(ret)};
}

}  // namespace transformer_engine::pytorch