/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#define PYBIND11_DETAILED_ERROR_MESSAGES  // TODO remove

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_

#include <Python.h>

#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <torch/torch.h>

#include <array>
#include <memory>
#include <tuple>

#include "common.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine::pytorch {

23
24
25
26
27
28
29
30
31
32
#define NVTE_SCOPED_GIL_RELEASE(code_block)      \
  do {                                           \
    if (PyGILState_Check()) {                    \
      pybind11::gil_scoped_release _gil_release; \
      code_block                                 \
    } else {                                     \
      code_block                                 \
    }                                            \
  } while (false);

33
extern PyTypeObject *Float8TensorPythonClass;
34
extern PyTypeObject *Float8TensorStoragePythonClass;
35
extern PyTypeObject *Float8QuantizerClass;
36
extern PyTypeObject *Float8CurrentScalingQuantizerClass;
37
extern PyTypeObject *MXFP8TensorPythonClass;
38
extern PyTypeObject *MXFP8TensorStoragePythonClass;
39
extern PyTypeObject *MXFP8QuantizerClass;
40
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
41
extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass;
42
extern PyTypeObject *Float8BlockwiseQuantizerClass;
43
extern PyTypeObject *NVFP4TensorPythonClass;
44
extern PyTypeObject *NVFP4TensorStoragePythonClass;
45
extern PyTypeObject *NVFP4QuantizerClass;
46
47
48
49
50

void init_extension();

namespace detail {

51
52
53
54
55
inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }

inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) {
  return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass;
}
56
57

inline bool IsFloat8Tensor(PyObject *obj) {
58
  return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass;
59
60
}

61
inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }
62
63

inline bool IsMXFP8Tensor(PyObject *obj) {
64
  return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass;
65
66
}

67
68
69
70
inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
  return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}

71
72
inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; }

73
74
inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
  return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
75
         Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass;
76
77
}

78
inline bool IsNVFP4Tensor(PyObject *obj) {
79
  return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass;
80
81
}

82
83
84
85
86
87
88
89
90
91
92
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);

template <typename T>
std::unique_ptr<Quantizer> CreateQuantizer(const py::handle quantizer) {
  return std::make_unique<T>(quantizer);
}

TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params);

std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);

93
94
95
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
                                                   Quantizer *quantization_params);

96
97
TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer);

98
99
100
101
102
inline bool IsFloatingPointType(at::ScalarType type) {
  return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}

constexpr std::array custom_types_converters = {
103
    std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor,
104
                    CreateQuantizer<Float8Quantizer>),
105
106
107
    std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
                    CreateQuantizer<Float8CurrentScalingQuantizer>),
    std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
108
109
                    CreateQuantizer<MXFP8Quantizer>),
    std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
110
111
112
                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>),
    std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor,
                    CreateQuantizer<NVFP4Quantizer>)};
113
114
115
116
117
}  // namespace detail

}  // namespace transformer_engine::pytorch

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_