Unverified Commit 7022d50f authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Quantizer as API (#2039)



* Introduce QuantizerBase
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose as a first-class API
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Undo QuantizerBase
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Make Quantizer a base class without implementations
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Support CustomRecipe and CustomRecipeState
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Resolving comments: quantize impl, num_quantizers, defaults
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Quantizer factories
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add tests
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* QuantizedTensorBase _get_quantizer() + quantize_()
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Experimental note + LayerNormMLP fix
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor._internal -> tensor.base
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor import fix
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Single quantizer factory with roles
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* More context for qfactory, fwd/bwd_roles
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename *Base -> *Storage quantized tensors
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* make_quantizers() will take roles from the operation
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Improve tests and fix missing imports
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Merge main followup
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ce18bee7
......@@ -29,7 +29,7 @@ from ...module.base import (
)
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...utils import (
canonicalize_device,
canonicalize_dtype,
......@@ -568,7 +568,7 @@ class BasicLinear(BasicOperation):
# Prepare input tensor for backward pass
if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
......
......@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize
from ..op import BasicOperation, OperationContext
......@@ -56,7 +56,7 @@ class Dropout(BasicOperation):
out = input_
elif impl == "fused":
x = input_
if not isinstance(x, Float8TensorBase):
if not isinstance(x, Float8TensorStorage):
x = maybe_dequantize(x, dtype=dtype)
out, mask = tex.dropout_fwd(x, self.dropout_probability)
elif impl == "unfused":
......
......@@ -23,7 +23,7 @@ from ...module.base import (
)
from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
......@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare input tensor for backward pass
if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
......
......@@ -6,12 +6,42 @@
import torch
from .quantized_tensor import QuantizedTensor, Quantizer
from .quantized_tensor import (
QuantizedTensorStorage,
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
from .storage.float8_tensor_storage import Float8TensorStorage
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer
from .utils import cast_master_weights_to_fp8, replace_raw_data
__all__ = [
"QuantizedTensor",
"Quantizer",
"Float8Quantizer",
"Float8CurrentScalingQuantizer",
"MXFP8Quantizer",
"Float8BlockQuantizer",
"NVFP4Quantizer",
"QuantizedTensorStorage",
"Float8TensorStorage",
"MXFP8TensorStorage",
"Float8BlockwiseQTensorStorage",
"NVFP4TensorStorage",
"QuantizedTensor",
"Float8Tensor",
"MXFP8Tensor",
"Float8BlockwiseQTensor",
"NVFP4Tensor",
"prepare_for_saving",
"restore_from_saved",
]
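
With these names exported from transformer_engine.pytorch.tensor, user code can build and apply a quantizer without reaching into module internals. A minimal usage sketch, assuming a CUDA device and the delayed-scaling Float8Quantizer constructor (scale, amax, fp8_dtype); illustrative, not canonical:

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor import Float8Quantizer, Float8Tensor

# Delayed-scaling FP8 quantizer; scale/amax are single-element GPU buffers.
quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)

x = torch.randn(128, 128, device="cuda")
x_fp8 = quantizer(x)          # __call__ routes through Quantizer.quantize()
assert isinstance(x_fp8, Float8Tensor)
x_hp = x_fp8.dequantize()     # back to high precision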
......@@ -48,24 +78,16 @@ def get_all_tensor_types():
"""
Get all tensor-like types that can be used in TE.
"""
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase
all_tensor_types = [
torch.Tensor,
torch.nn.Parameter,
Float8Tensor,
Float8TensorBase,
Float8TensorStorage,
MXFP8Tensor,
MXFP8TensorBase,
MXFP8TensorStorage,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
Float8BlockwiseQTensorStorage,
NVFP4Tensor,
NVFP4TensorBase,
NVFP4TensorStorage,
]
return all_tensor_types
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Internal data structures for quantized tensors."""
......@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten
......@@ -101,6 +105,10 @@ class Float8BlockQuantizer(Quantizer):
dst._fp8_dtype = self.dtype
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
......@@ -270,7 +278,7 @@ class Float8BlockQuantizer(Quantizer):
return Float8BlockScaling
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype,
......@@ -295,7 +303,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
holds configuration about quantization and dequantization modes.
"""
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
......@@ -334,15 +342,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
f" data_format={self._data_format}"
)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
assert self._quantizer is not None
return self._quantizer
def quantize_(
self,
tensor: torch.Tensor,
......@@ -361,8 +360,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
return super().quantize_(tensor, noop_flag=noop_flag)
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
......
......@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..constants import dist_group_type
aten = torch.ops.aten
......@@ -89,6 +93,10 @@ class Float8Quantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def make_empty(
self,
shape: Iterable[int],
......@@ -147,7 +155,7 @@ class Float8Quantizer(Quantizer):
torch.float8_e5m2fnuz,
]
if internal:
return Float8TensorBase(
return Float8TensorStorage(
data=data,
fp8_scale_inv=1 / self.scale,
fp8_dtype=self.dtype,
......@@ -271,6 +279,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def make_empty(
self,
shape: Iterable[int],
......@@ -333,7 +345,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
torch.float8_e5m2fnuz,
]
if internal:
return Float8TensorBase(
return Float8TensorStorage(
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
fp8_dtype=self.dtype,
......@@ -388,7 +400,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
return True
class Float8Tensor(Float8TensorBase, QuantizedTensor):
class Float8Tensor(Float8TensorStorage, QuantizedTensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
......@@ -443,19 +455,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
return _FromFloat8Func.apply(self, dtype)
return _FromFloat8Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
# Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling)
raise ValueError(
"Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable"
)
def quantize_(
self,
tensor: torch.Tensor,
......@@ -474,8 +473,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
return super().quantize_(tensor, noop_flag=noop_flag)
def detach(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
......
......@@ -16,8 +16,12 @@ from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
aten = torch.ops.aten
......@@ -67,6 +71,10 @@ class MXFP8Quantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
......@@ -161,14 +169,14 @@ class MXFP8Quantizer(Quantizer):
data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor:
def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return MXFP8BlockScaling
class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
"""Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype,
......@@ -192,7 +200,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args,
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
......@@ -236,17 +244,9 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
return _FromMXFP8Func.apply(self, dtype)
return _FromMXFP8Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return MXFP8Quantizer(
fp8_dtype=self._fp8_dtype,
)
def _build_default_quantizer(self) -> Optional[Quantizer]:
"""Build default quantizer for the tensor"""
return MXFP8Quantizer(fp8_dtype=self._fp8_dtype)
def quantize_(
self,
......@@ -266,8 +266,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
return super().quantize_(tensor, noop_flag=noop_flag)
def detach(self) -> MXFP8Tensor:
# pylint: disable=missing-function-docstring
......
......@@ -21,7 +21,7 @@ from ..utils import (
round_up_to_nearest_multiple,
)
from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
aten = torch.ops.aten
......@@ -173,6 +173,10 @@ class NVFP4Quantizer(Quantizer):
return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
......@@ -332,7 +336,7 @@ class NVFP4Quantizer(Quantizer):
return NVFP4BlockScaling
class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor):
class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
"""Quantized tensor class with FP4 data
The tensor presents as having a standard, higher-precision dtype,
......@@ -365,7 +369,7 @@ class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor):
Nominal tensor datatype, used in dequantize.
"""
# NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args,
# NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
......
......@@ -5,7 +5,7 @@
"""Tensor with quantized data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union
from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
......@@ -13,12 +13,11 @@ import warnings
import torch
from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
class QuantizedTensorBase:
r"""Base class for all *TensorBase classes.
class QuantizedTensorStorage:
r"""Base class for all *TensorStorage classes.
This class (and its subclasses) is an optimization for when
the full QuantizedTensor is not needed (when it is fully
......@@ -26,9 +25,9 @@ class QuantizedTensorBase:
PyTorch's autograd).
When creating a new tensor type X one should create both
XTensorBase class inheriting from QuantizedTensorBase and
XTensor inheriting from XTensorBase and QuantizedTensor.
XTensorBase should contain all data members needed to
XTensorStorage class inheriting from QuantizedTensorStorage and
XTensor inheriting from XTensorStorage and QuantizedTensor.
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
to behave like a regular torch.Tensor (like __torch_dispatch__)."""
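
As the renamed docstring spells out, each format is split into a Storage mixin that owns the raw buffers and a full tensor class that layers torch.Tensor behavior on top. A hypothetical sketch of that split for a new format; MyFormatTensorStorage/MyFormatTensor and their fields are invented for illustration, and the __new__ plumbing the real classes use is elided here:

from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    QuantizedTensorStorage,
)

class MyFormatTensorStorage(QuantizedTensorStorage):
    """Holds the raw quantized buffers; cheap to save/restore without autograd."""

    def __init__(self, data, scale_inv, quantizer=None):
        self._data = data
        self._scale_inv = scale_inv
        self._quantizer = quantizer

    def prepare_for_saving(self):
        # Hand the raw buffers to save_for_backward, keep only the shell.
        tensors = [self._data, self._scale_inv]
        self._data, self._scale_inv = None, None
        return tensors, self

    def restore_from_saved(self, tensors):
        # Reclaim our buffers and pass the rest along.
        self._data, self._scale_inv = tensors[0], tensors[1]
        return tensors[2:]

class MyFormatTensor(MyFormatTensorStorage, QuantizedTensor):
    """Adds the torch.Tensor-facing behavior (__torch_dispatch__ and friends)."""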
......@@ -59,7 +58,7 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement update_usage function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement prepare_for_saving function"
......@@ -73,6 +72,30 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement restore_from_saved function"
)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return self._build_default_quantizer()
def _build_default_quantizer(self) -> Quantizer:
"""Build default quantizer for the tensor"""
raise ValueError(
f"{self.__class__.__name__} has no quantizer "
"and no default quantizer is available defined in the subclass."
)
def quantize_(
self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None
) -> QuantizedTensor:
"""Quantize tensor in-place"""
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def update_quantizer(self, quantizer: Quantizer):
"""Update quantizer for the tensor"""
if self._quantizer is None:
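
Hoisting _get_quantizer() and quantize_() into this base class gives every storage type the same in-place path: fetch the tensor's quantizer, or a subclass default such as MXFP8Tensor._build_default_quantizer below, and delegate to update_quantized. A hedged usage sketch (same assumed Float8Quantizer constructor as above; noop_flag semantics per the usual TE convention that a nonzero value skips the update):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor import Float8Quantizer

quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)

# Allocate an uninitialized FP8 tensor, then quantize into it in place.
y = quantizer.make_empty((128, 128), dtype=torch.float32, device=torch.device("cuda"))
y.quantize_(torch.randn(128, 128, device="cuda"))

# A single-element noop flag lets pre-recorded (e.g. CUDA graph) code
# skip the copy dynamically without re-recording.
flag = torch.zeros(1, device="cuda")
y.quantize_(torch.randn(128, 128, device="cuda"), noop_flag=flag)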
......@@ -83,13 +106,13 @@ class QuantizedTensorBase:
def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorBase],
*tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]]
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]]
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal TensorBase types too."""
the internal *TensorStorage types too."""
tensor_list, tensor_objects_list = [], []
for tensor in tensors:
......@@ -104,12 +127,12 @@ def prepare_for_saving(
def restore_from_saved(
tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]],
tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
return_saved_tensors: bool = False,
) -> (
list[Optional[torch.Tensor | QuantizedTensorBase]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]]
list[Optional[torch.Tensor | QuantizedTensorStorage]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
):
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
......@@ -178,7 +201,6 @@ class Quantizer(abc.ABC):
")"
)
@abc.abstractmethod
def update_quantized(
self,
src: torch.Tensor,
......@@ -187,6 +209,9 @@ class Quantizer(abc.ABC):
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Quantize tensor in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_quantized"
)
def quantize(
self,
......@@ -199,8 +224,14 @@ class Quantizer(abc.ABC):
if out is not None:
return self.update_quantized(tensor, out)
if (not self.internal) and torch.is_grad_enabled():
return _QuantizeFunc.apply(tensor, self)
return _QuantizeFunc.forward(None, tensor, self)
return _QuantizeFunc.apply(tensor, self.quantize_impl)
return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement quantize_impl function"
)
def multi_quantize(self, list_of_tensors):
"""Quantize multiple tensors"""
......@@ -213,7 +244,6 @@ class Quantizer(abc.ABC):
"""Quantize tensor"""
return self.quantize(tensor)
@abc.abstractmethod
def make_empty(
self,
shape: Iterable[int],
......@@ -222,8 +252,11 @@ class Quantizer(abc.ABC):
device: Optional[torch.device] = None,
) -> QuantizedTensor:
"""Construct quantized tensor with uninitialized data"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function, "
"required for construction of unintialized quantized tensor"
)
@abc.abstractmethod
def calibrate(self, tensor: torch.Tensor) -> None:
"""Calibrate quantizer state
......@@ -252,13 +285,21 @@ class Quantizer(abc.ABC):
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Symbolic function for ONNX export"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement onnx_quantize"
)
def onnx_dequantize(self, tensor) -> torch.Tensor:
"""Symbolic function for ONNX export"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement onnx_dequantize"
)
@abc.abstractmethod
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_compatible_recipe"
)
def supports_only_rowwise_all_gather(self) -> bool:
"""Returns True if the quantizer supports only rowwise all-gather"""
......@@ -270,20 +311,21 @@ class Quantizer(abc.ABC):
class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
"""Quantize tensor"""
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: torch.Tensor,
quantizer: Quantizer,
quantize_impl: Callable,
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
return tex.quantize(tensor, quantizer)
return quantize_impl(tensor)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
......
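
Because quantize() now hands the bound self.quantize_impl to _QuantizeFunc, a subclass override is honored on both the autograd and no-grad paths, and the NotImplementedError stubs that replace the abstract methods mean a custom quantizer overrides only the entry points it needs. A minimal illustrative subclass (LoggingQuantizer is invented here, not shipped by this PR; it is constructed exactly like its parent):

import torch
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer

class LoggingQuantizer(Float8CurrentScalingQuantizer):
    """Records input amax before delegating to the stock FP8 path."""

    def quantize_impl(self, tensor: torch.Tensor):
        print(f"quantize {tuple(tensor.shape)}: amax={tensor.abs().max().item():.4g}")
        return super().quantize_impl(tensor)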
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Storage for quantized tensors."""
from .float8_tensor_storage import Float8TensorStorage # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401
......@@ -13,7 +13,7 @@ import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType_To_Torch
......@@ -22,7 +22,7 @@ from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
class Float8BlockwiseQTensorBase(QuantizedTensorBase):
class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
......@@ -53,7 +53,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
*args,
**kwargs,
):
if cls is Float8BlockwiseQTensorBase:
if cls is Float8BlockwiseQTensorStorage:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
......@@ -98,7 +98,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:
"""
Prepare the tensor base for saving for backward
"""
......@@ -366,7 +366,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
data = self.dequantize()
descriptor = "columnwise"
return (
"Float8BlockwiseQTensorBase("
"Float8BlockwiseQTensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}"
)
......
......@@ -12,7 +12,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType as torch_to_transformer_engine_dtype
......@@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function):
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: Float8TensorBase,
tensor: Float8TensorStorage,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -52,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None
class Float8TensorBase(QuantizedTensorBase):
class Float8TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of Float8Tensor.
Float8Tensor inherits from the PyTorch tensor class and this mixin
......@@ -81,7 +81,7 @@ class Float8TensorBase(QuantizedTensorBase):
quantizer: Optional[Quantizer] = None,
**kwargs,
):
if cls is Float8TensorBase:
if cls is Float8TensorStorage:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
......@@ -116,7 +116,7 @@ class Float8TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._data, self._transpose, self._scale_inv]
self._data = None
......@@ -163,7 +163,7 @@ class Float8TensorBase(QuantizedTensorBase):
if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
out_transpose = None
return Float8TensorBase(
return Float8TensorStorage(
data=out_data,
fp8_scale_inv=self._scale_inv,
fp8_dtype=self._fp8_dtype,
......@@ -173,7 +173,7 @@ class Float8TensorBase(QuantizedTensorBase):
def __repr__(self):
return (
"Float8TensorBase("
"Float8TensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, "
f"data={self.dequantize()}"
......
......@@ -13,7 +13,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType as torch_to_transformer_engine_dtype
......@@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function):
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: MXFP8TensorBase,
tensor: MXFP8TensorStorage,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -49,7 +49,7 @@ class _FromMXFP8Func(torch.autograd.Function):
return grad, None
class MXFP8TensorBase(QuantizedTensorBase):
class MXFP8TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of MXFP8Tensor.
MXFP8Tensor inherits from the PyTorch tensor class and this mixin
......@@ -77,7 +77,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
*args,
**kwargs,
):
if cls is MXFP8TensorBase:
if cls is MXFP8TensorStorage:
instance = object.__new__(cls)
else:
instance = super().__new__(cls, *args, **kwargs)
......@@ -112,7 +112,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
"""Prepare the tensor base for saving for backward"""
tensors = [
self._rowwise_data,
......@@ -192,7 +192,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
if cur_columnwise_data is not None:
new_columnwise_data = cur_columnwise_data.view(*shape)
return MXFP8TensorBase(
return MXFP8TensorStorage(
rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
......@@ -205,7 +205,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
data_rowwise = self.dequantize()
return (
"MXFP8TensorBase("
"MXFP8TensorStorage("
f"fp8_dtype={self._fp8_dtype}, "
f"rowwise_scaled_data={data_rowwise}"
f"rowwise_scale_inv={self._rowwise_scale_inv}, "
......
......@@ -16,7 +16,7 @@ import torch
# import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
from ..quantized_tensor import QuantizedTensorStorage
# from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
......@@ -39,7 +39,7 @@ class _FromNVFP4Func(torch.autograd.Function):
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: NVFP4TensorBase,
tensor: NVFP4TensorStorage,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
......@@ -89,7 +89,7 @@ class _FromNVFP4Func(torch.autograd.Function):
return grad, None
class NVFP4TensorBase(QuantizedTensorBase):
class NVFP4TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of NVFP4Tensor.
NVFP4Tensor inherits from the PyTorch tensor class and this mixin
......@@ -161,7 +161,7 @@ class NVFP4TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]:
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
"""Prepare the tensor base for saving for backward"""
tensors = [
self._rowwise_data,
......@@ -267,7 +267,7 @@ class NVFP4TensorBase(QuantizedTensorBase):
new_columnwise_data = self._columnwise_data.view(byte_shape)
# Construct tensor
return NVFP4TensorBase(
return NVFP4TensorStorage(
rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
......@@ -282,7 +282,7 @@ class NVFP4TensorBase(QuantizedTensorBase):
data_rowwise = self.dequantize()
return (
"NVFP4TensorBase("
"NVFP4TensorStorage("
f"rowwise_scaled_data={data_rowwise},"
f"rowwise_scale_inv={self._rowwise_scale_inv},"
f"amax_rowwise={self._amax_rowwise},"
......
......@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
......@@ -454,7 +454,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
)
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool:
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
"""Check if an environment or object is using experimental Kitchen middleware.
Returns False if x is a torch.Tensor.
......@@ -466,6 +466,6 @@ def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -
# Detect if the object is experimental
if isinstance(x, torch.Tensor):
return False
if not isinstance(x, (Quantizer, QuantizedTensorBase)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance")
if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
return hasattr(x, "experimental") and x.experimental
......@@ -225,13 +225,15 @@ class SplitAlongDim(torch.autograd.Function):
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import (
Float8TensorStorage,
)
if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance(
mixed_x_layer, Float8Tensor
):
return tuple(
Float8TensorBase(
Float8TensorStorage(
fp8_scale_inv=mixed_x_layer._scale_inv,
fp8_dtype=mixed_x_layer._fp8_dtype,
data=x.squeeze(split_dim) if squeeze else x,
......