Unverified Commit 1bbeab1c authored by kwyss-nvidia, committed by GitHub

Blockwise float8 quantizer and quantized tensor class (#1513)



* Blockwise float8 quantizer and quantized tensor class.

The classes are configurable for 128x128 and 1x128 block sizes
by setting block_scaling_dim to 2 or 1, respectively.

Scale tensors are stored in a format amenable to matrix multiplication;
however, the integration of matmul is deferred as a separate story.

Fusions of quantization with DBIAS or activation functions are not yet
implemented, and dequantization is currently implemented in torch.

Tests for quantization are included at the C++ and PyTorch layers, with
exact comparison to reference quantizer behavior, as well as an attempt
to hit interesting branches through the API, such as tensor creation
in PyTorch and C++ and dequantization of rowwise and columnwise usage.

Two CUDA kernels for quantization are included; they are direct ports
of their equivalents in the kitchen repository, where a subchannel recipe
has been used for end-to-end training.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Apply linting changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Alignment for 1D scaling for GEMM edge case.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Change API name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix merge conflict with name change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use common tensor map API.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Change API to use two scaling mode enums.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update some call sites.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Tests for torch tensor API surface.

Since the quantized tensor is a tensor
subclass, these tests exercise torch hooks.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reuse scale calculation between quantizer refs.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Save memory by dropping references to saved tensors.

Previously observed issues are resolved.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove constexpr parameters from kernel.

Code size is reduced with fewer constexpr params.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Merge conflict from rebase.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add shape implementations for block scaling.

nvte_shape was added upstream. Logic added
for block-scaled fp8.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Move benchmark to te_playground.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove amax_epsilon and pow_2_scales from tensor.

Hardcodes the default values.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Lint changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix up MR changes that broke.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Safer ifdef in kernel.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Documentation prose.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reuse compute_scale function from Current Scaling.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Bugfix on inf_value scale refactor.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove qopt calls from test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update pytest list.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to reference scale calc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use ptx.cuh functions instead of cde.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update shape logic with allocation and reuse shape.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Usage defaults MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Copyright and header guard.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update torch dispatch code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix exception type.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use TypeInfo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update CS scale update test to use updated ref impl.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX scaling mode enum.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Skip tests on Lovelace.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 3e305f72
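For orientation before the diff: a minimal pure-PyTorch sketch of the 1x128 blockwise recipe the commit message describes, with scales rounded down to powers of two (matching force_pow_2_scales). This is reference math only, not the CUDA kernels added below; the function name and the amax floor value are illustrative assumptions.

import torch

def ref_blockwise_quantize_1x128(x: torch.Tensor, block_len: int = 128):
    """Quantize a 2D tensor to float8_e4m3fn with one scale per 1x128 block."""
    M, K = x.shape
    assert K % block_len == 0, "sketch assumes K is a multiple of block_len"
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    tiles = x.reshape(M, K // block_len, block_len).float()
    amax = tiles.abs().amax(dim=-1, keepdim=True)
    scale = fp8_max / amax.clamp(min=1e-12)
    # Round scales down to powers of two (force_pow_2_scales).
    scale = torch.exp2(torch.floor(torch.log2(scale)))
    q = (tiles * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(M, K), (1.0 / scale).squeeze(-1)  # data, scale_inv

x = torch.randn(256, 512)
q, scale_inv = ref_blockwise_quantize_1x128(x)
x_dq = q.reshape(256, 4, 128).float() * scale_inv.unsqueeze(-1)  # dequantize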
...@@ -1262,6 +1262,30 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
          workspace_tensor, stream);
      break;
    }
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
constexpr bool force_pow_2_scales = true;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data,
/*epsilon=*/0.0,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
constexpr bool force_pow_2_scales = true;
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data,
/*epsilon=*/0.0,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
...
...@@ -349,6 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
      NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
    }
  } else {
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
} }
} }
......
...@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( ...@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
: "memory"); : "memory");
} }
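// Poll an mbarrier in shared memory: try_wait returns true once the barrier's
// phase parity flips; mbarrier_wait_parity below spins until it does.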
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
...@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
      : "memory");
}
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() {
  asm volatile("cp.async.bulk.wait_group 0;");
...@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
  asm volatile("cp.async.bulk.wait_group.read 4;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); }

__device__ __forceinline__ void fence_proxy_async_shared_cta() {
  asm volatile("fence.proxy.async.shared::cta;");
}

#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}  // namespace ptx
...
...@@ -183,8 +183,8 @@ class ScalingMode(Enum):
    NVTE_DELAYED_TENSOR_SCALING = 0
    NVTE_MXFP8_1D_SCALING = 1
    NVTE_INVALID_SCALING = 4
    NVTE_NO_SCALING = 5

    def _get_impl(self) -> ScalingModeMetadataImpl:
        """Get the implementation for this scaling mode.
...
...@@ -24,6 +24,12 @@ TE_DType = {
    torch.bfloat16: tex.DType.kBFloat16,
}
"""
This is a map: int -> torch.dtype
Used for resolving cuda extension types to torch.
Has one to one mapping with enum in
transformer_engine.h
"""
TE_DType_To_Torch = {
    tex.DType.kByte: torch.uint8,
    tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
...
...@@ -163,6 +163,38 @@ class Float8CurrentScalingQuantizer : public Quantizer {
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class Float8BlockQuantizer : public Quantizer {
public:
// Which float8 type is used for q data.
DType dtype;
private:
// Options about how to quantize the tensor
// Quantization scales are rounded down to powers of 2.
bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0;
int block_scaling_dim = 2;
public:
// Initializes from a python handle to a Float8BlockQuantizer
explicit Float8BlockQuantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override {
return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
}
// Gets rowwise and columnwise_data from tensor and sets them on wrapper
void set_quantization_params(TensorWrapper* tensor) const override;
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
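// Scale-inv shape summary (see create_tensor later in this diff; kBlockLen = 128,
// second dim rounded up to a multiple of 4 for GEMM alignment):
//   rowwise, block_scaling_dim == 2: ceil(M/128) x roundup(ceil(K/128), 4)
//   rowwise, block_scaling_dim == 1: ceil(K/128) x roundup(M, 4)
//   (columnwise usage swaps the roles of M and K.)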
class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;
...
...@@ -12,6 +12,8 @@
// #include <torch/all.h>
#include <assert.h>
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
...@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor {
    n -= chunk_idx * chunk_size;
    for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
      float scale_val = transformer_engine::compute_scale_from_amax(
          amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits<float>::max());
      scale[i_start] = scale_val;
      transformer_engine::reciprocal(scale_inv + i_start, scale_val);
    }
...
...@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr;  /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
void init_float8_extension() {
  if (Float8TensorPythonClass) return;
...@@ -61,9 +64,31 @@ void init_mxfp8_extension() {
             "Internal error: could not initialize pyTorch MXFP8 extension.");
}
void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorBasePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
}
void init_extension() {
  init_float8_extension();
  init_mxfp8_extension();
init_float8blockwise_extension();
}

}  // namespace transformer_engine::pytorch
...@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("output") = py::none(), py::arg("noop") = py::none());
  m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
        py::arg("otype"));
  m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
        "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
  m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
...
...@@ -250,6 +250,142 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
    tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
  }
  this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
}
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast<bool>(),
"Pending additional parameters to the nvte_quantize API, "
"float8 block quantization requires pow2 scales");
NVTE_CHECK(quantizer.attr("amax_epsilon").cast<float>() == 0.0,
"Pending additional parameters to the nvte_quantize API, "
"float8 block quantization requires amax_epsilon==0");
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
// Change the rowwise and columnwise_data to the configured dtype.
// May be a switch between E5M2 and E4M3.
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(this->get_scaling_mode());
at::TensorOptions opts;
at::TensorOptions scale_opts;
at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data_rowwise = std::move(*rowwise_data);
} else {
data_rowwise = at::empty(torch_shape, opts);
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(m_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts);
tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{sinv0, sinv1});
}
if (columnwise_usage) {
std::vector<int64_t> torch_columnwise_shape;
std::vector<size_t> columnwise_shape;
NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
columnwise_shape, " torch shape: ", torch_columnwise_shape);
if (torch_shape.size() > 0) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(k_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
data_colwise = at::empty(torch_columnwise_shape, opts);
scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts);
tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape);
tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{sinv0, sinv1});
}
this->set_quantization_params(&tensor);
py::object ret;
if (internal) {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
ret = Float8BlockwiseQTensorClass(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
"is_2D_scaled"_a = (block_scaling_dim == 2));
} else {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
ret = Float8BlockwiseQTensorClass(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
}
  return {std::move(tensor), std::move(ret)};
}
...@@ -302,8 +438,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
    auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
    rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
    tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(
        rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
        std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
  }

  if (columnwise_usage) {
...@@ -313,8 +450,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
    columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
    tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
    tensor.set_columnwise_scale_inv(
        columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
        std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
  }

  this->set_quantization_params(&tensor);
...
...@@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
  return ret;
}
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto &rowwise_shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape);
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto &shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
quantizer->set_quantization_params(&ret);
return ret;
}
}  // namespace detail
}  // namespace transformer_engine::pytorch
...@@ -25,6 +25,9 @@ extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;
void init_extension();
...@@ -50,6 +53,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) {
  return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
}
inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}
inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
}
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);

template <typename T>
...@@ -61,6 +73,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati
std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
Quantizer *quantization_params);
inline bool IsFloatingPointType(at::ScalarType type) {
  return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}
...@@ -71,7 +86,9 @@ constexpr std::array custom_types_converters = {
    std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
                    CreateQuantizer<Float8CurrentScalingQuantizer>),
    std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
                    CreateQuantizer<MXFP8Quantizer>),
std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};
}  // namespace detail
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Mixin class holding data specific for Float8BlockwiseQTensor"""
from __future__ import annotations
import math
from typing import Optional, Dict, Any, Tuple
import torch
from transformer_engine_torch import DType as TE_DType
from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer
class Float8BlockwiseQTensorBase:
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
mixin class. If this class is instantiated directly, it has the same
data, lower CPU overhead, and less functionality. It should only
be instantiated directly for performance-critical internal usage.
"""
_rowwise_data: Optional[torch.Tensor]
_columnwise_data: Optional[torch.Tensor]
_quantizer: Quantizer
_fp8_dtype: TE_DType
_rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool
def __new__(
cls,
*args,
rowwise_data: torch.Tensor,
rowwise_scale_inv: torch.Tensor,
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled
return instance
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
"rowwise_data": self._rowwise_data,
"rowwise_scale_inv": self._rowwise_scale_inv,
"columnwise_data": self._columnwise_data,
"columnwise_scale_inv": self._columnwise_scale_inv,
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
}
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._rowwise_data, self._columnwise_data]
self._rowwise_data = None
self._columnwise_data = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
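        # Save/restore contract sketch (illustrative, not TE API): an autograd
        # function would do
        #   tensors, holder = qtensor.prepare_for_saving()  # holder drops data refs
        #   ctx.save_for_backward(*tensors)
        # and in backward
        #   rest = holder.restore_from_saved(saved_list)  # consumes its own two
        #                                                 # entries, returns the tail
        # so several tensor bases can chain consumption of one saved list.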
def get_data_tensors(self):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
"""Takes dequantized columnwise data and permutes to a rowwise shape"""
if columnwise_dq.dim() < 2:
return columnwise_dq
permute_dims = list(range(1, columnwise_dq.dim()))
permute_dims.append(0)
return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()
def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
block_len = 128
q_M, q_K = 1, 1
if self._rowwise_data is not None:
q = self._rowwise_data
scale_inv = self._rowwise_scale_inv
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
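        # Scales are stored transposed relative to q: (k_tiles, scale_m), i.e. one
        # scale per 1x128 segment of q's last dim, with scale_m padded to a multiple of 4.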
k_tiles, scale_m = scale_inv.shape
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
_, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len)
if scale_m > q_M:
            # scale_m is padded to a multiple of 4 elements; drop the padding.
scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K]
result = result.to(dtype)
if len(orig_shape) == 0:
result = result.reshape([])
else:
result = result.reshape(*orig_shape).contiguous()
if transpose_output:
return self._transpose_dq_columnwise_output(result)
return result
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
"""
block_len = 128
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
def format_scale_as_logical_shape(q_K, scales, block_len):
            # The GEMM for 2D blocks requires padding in the scales; strip it here.
derived_scale_k_shape = math.ceil(q_K / block_len)
_, scale_K = scales.shape
if derived_scale_k_shape == scale_K:
return scales
return scales[:, :derived_scale_k_shape].contiguous()
q_M, q_K = 1, 1
if self._rowwise_data is not None:
q = self._rowwise_data
scale_inv = self._rowwise_scale_inv
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
formatted_scales = format_scale_as_logical_shape(q_K, scale_inv, block_len)
assert len(formatted_scales.shape) == 2
m_tiles, k_tiles = formatted_scales.shape
unpadded_m, unpadded_k = q_M, q_K
m_block_len = block_len
k_block_len = block_len
if q_M % m_block_len != 0 or q_K % k_block_len != 0:
m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len
k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view(
m_tiles, 1, k_tiles, 1
)
result = result.view(padded_M, padded_K).to(dtype)
if padded_M != unpadded_m or padded_K != unpadded_k:
result = result[:unpadded_m, :unpadded_k]
if len(orig_shape) == 0:
result = result.reshape([])
else:
result = result.reshape(*orig_shape).contiguous()
if transpose_output:
return self._transpose_dq_columnwise_output(result)
return result
def size(self, *args, **kwargs):
# pylint: disable=missing-function-docstring
if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs)
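        # Only columnwise (transposed) data is present: rotate its dims back to
        # the logical rowwise order before reporting the size.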
dims = list(self._columnwise_data.size(*args, **kwargs))
reordered = []
for i in range(1, len(dims)):
reordered.append(dims[i])
reordered.append(dims[0])
return torch.Size(reordered)
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
descriptor = "rowwise"
else:
data = self.dequantize()
descriptor = "columnwise"
        return (
            "Float8BlockwiseQTensorBase("
            f"fp8_dtype={self._fp8_dtype}, "
            f"{descriptor}_scaled_data={data})"
        )
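A hedged end-to-end sketch of the torch-side dequantization path above. It assumes a Transformer Engine build containing this commit and a CUDA device, and chooses shapes so that no scale padding is needed; passing quantizer=None is an illustrative shortcut.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import (
    Float8BlockwiseQTensorBase,
)

M, K, block = 256, 512, 128
data = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# 1D rowwise scales are stored transposed: (K // block) x M, M padded to 4.
scale_inv = torch.ones(K // block, M, device="cuda")
qt = Float8BlockwiseQTensorBase(
    rowwise_data=data,
    rowwise_scale_inv=scale_inv,
    columnwise_data=None,
    columnwise_scale_inv=None,
    fp8_dtype=tex.DType.kFloat8E4M3,
    quantizer=None,
    is_2D_scaled=False,
)
dq = qt.dequantize(dtype=torch.bfloat16)  # runs _dequantize_vectorwise in torch
assert dq.shape == (M, K) and dq.dtype == torch.bfloat16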