Unverified Commit 7022d50f authored by Evgeny Tsykunov's avatar Evgeny Tsykunov Committed by GitHub
Browse files

[PyTorch] Quantizer as API (#2039)



* Introduce QuantizerBase
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Expose as a first-class API
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Undo QuantizerBase
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Make Quantizer a base class without implementations
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Support CustomRecipe and CustomRecipeState
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Resolving comments: quantize impl, num_quantizers, defaults
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Quantizer factories
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Add tests
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* QuantizedTensorBase _get_quantizer() + quantize_()
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Experimental note + LayerNormMLP fix
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor._internal -> tensor.base
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Expose
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor import fix
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Single quantizer factory with roles
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* More context for qfactory, fwd/bwd_roles
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Rename *Base -> *Storage quantized tensors
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* make_quantizers() will take roles from the operation
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* Improve tests and fix missing imports
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from code review
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Merge main followup
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarEvgeny <etsykunov@nvidia.com>
Signed-off-by: default avatarEvgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent ce18bee7
...@@ -29,7 +29,7 @@ from ...module.base import ( ...@@ -29,7 +29,7 @@ from ...module.base import (
) )
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -568,7 +568,7 @@ class BasicLinear(BasicOperation): ...@@ -568,7 +568,7 @@ class BasicLinear(BasicOperation):
# Prepare input tensor for backward pass # Prepare input tensor for backward pass
if weight_requires_grad: if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local): if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
# FP8 does not support all-gather of transpose data # FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
...@@ -56,7 +56,7 @@ class Dropout(BasicOperation): ...@@ -56,7 +56,7 @@ class Dropout(BasicOperation):
out = input_ out = input_
elif impl == "fused": elif impl == "fused":
x = input_ x = input_
if not isinstance(x, Float8TensorBase): if not isinstance(x, Float8TensorStorage):
x = maybe_dequantize(x, dtype=dtype) x = maybe_dequantize(x, dtype=dtype)
out, mask = tex.dropout_fwd(x, self.dropout_probability) out, mask = tex.dropout_fwd(x, self.dropout_probability)
elif impl == "unfused": elif impl == "unfused":
......
...@@ -23,7 +23,7 @@ from ...module.base import ( ...@@ -23,7 +23,7 @@ from ...module.base import (
) )
from ...tensor.quantized_tensor import Quantizer from ...tensor.quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_dequantize, is_quantized_tensor from .._common import maybe_dequantize, is_quantized_tensor
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import ( from ..op import (
...@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare input tensor for backward pass # Prepare input tensor for backward pass
if weight_requires_grad: if weight_requires_grad:
if with_quantized_compute and is_quantized_tensor(x_local): if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data # FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
......
...@@ -6,12 +6,42 @@ ...@@ -6,12 +6,42 @@
import torch import torch
from .quantized_tensor import QuantizedTensor, Quantizer from .quantized_tensor import (
QuantizedTensorStorage,
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
from .storage.float8_tensor_storage import Float8TensorStorage
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer
from .utils import cast_master_weights_to_fp8, replace_raw_data from .utils import cast_master_weights_to_fp8, replace_raw_data
__all__ = [ __all__ = [
"QuantizedTensor",
"Quantizer", "Quantizer",
"Float8Quantizer",
"Float8CurrentScalingQuantizer",
"MXFP8Quantizer",
"Float8BlockQuantizer",
"NVFP4Quantizer",
"QuantizedTensorStorage",
"Float8TensorStorage",
"MXFP8TensorStorage",
"Float8BlockwiseQTensorStorage",
"NVFP4TensorStorage",
"QuantizedTensor",
"Float8Tensor",
"MXFP8Tensor",
"Float8BlockwiseQTensor",
"NVFP4Tensor",
"prepare_for_saving",
"restore_from_saved",
] ]
...@@ -48,24 +78,16 @@ def get_all_tensor_types(): ...@@ -48,24 +78,16 @@ def get_all_tensor_types():
""" """
Get all tensor-like types that can be used in TE. Get all tensor-like types that can be used in TE.
""" """
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase
all_tensor_types = [ all_tensor_types = [
torch.Tensor, torch.Tensor,
torch.nn.Parameter, torch.nn.Parameter,
Float8Tensor, Float8Tensor,
Float8TensorBase, Float8TensorStorage,
MXFP8Tensor, MXFP8Tensor,
MXFP8TensorBase, MXFP8TensorStorage,
Float8BlockwiseQTensor, Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase, Float8BlockwiseQTensorStorage,
NVFP4Tensor, NVFP4Tensor,
NVFP4TensorBase, NVFP4TensorStorage,
] ]
return all_tensor_types return all_tensor_types
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Internal data structures for quantized tensors."""
...@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType ...@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..utils import devices_match, round_up_to_nearest_multiple from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten aten = torch.ops.aten
...@@ -101,6 +105,10 @@ class Float8BlockQuantizer(Quantizer): ...@@ -101,6 +105,10 @@ class Float8BlockQuantizer(Quantizer):
dst._fp8_dtype = self.dtype dst._fp8_dtype = self.dtype
return dst return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization. """Calculate the shape of the scaling tensor for blockwise quantization.
...@@ -270,7 +278,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -270,7 +278,7 @@ class Float8BlockQuantizer(Quantizer):
return Float8BlockScaling return Float8BlockScaling
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype, The tensor presents as having a standard, higher-precision dtype,
...@@ -295,7 +303,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -295,7 +303,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
holds configuration about quantization and dequantization modes. holds configuration about quantization and dequantization modes.
""" """
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args, # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++. # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__( def __new__(
cls, cls,
...@@ -334,15 +342,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -334,15 +342,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
f" data_format={self._data_format}" f" data_format={self._data_format}"
) )
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
assert self._quantizer is not None
return self._quantizer
def quantize_( def quantize_(
self, self,
tensor: torch.Tensor, tensor: torch.Tensor,
...@@ -361,8 +360,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -361,8 +360,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
""" """
if isinstance(tensor, QuantizedTensor): if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize()) return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) return super().quantize_(tensor, noop_flag=noop_flag)
return self
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
""" """
......
...@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType ...@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..constants import dist_group_type from ..constants import dist_group_type
aten = torch.ops.aten aten = torch.ops.aten
...@@ -89,6 +93,10 @@ class Float8Quantizer(Quantizer): ...@@ -89,6 +93,10 @@ class Float8Quantizer(Quantizer):
return dst return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def make_empty( def make_empty(
self, self,
shape: Iterable[int], shape: Iterable[int],
...@@ -147,7 +155,7 @@ class Float8Quantizer(Quantizer): ...@@ -147,7 +155,7 @@ class Float8Quantizer(Quantizer):
torch.float8_e5m2fnuz, torch.float8_e5m2fnuz,
] ]
if internal: if internal:
return Float8TensorBase( return Float8TensorStorage(
data=data, data=data,
fp8_scale_inv=1 / self.scale, fp8_scale_inv=1 / self.scale,
fp8_dtype=self.dtype, fp8_dtype=self.dtype,
...@@ -271,6 +279,10 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -271,6 +279,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
return dst return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def make_empty( def make_empty(
self, self,
shape: Iterable[int], shape: Iterable[int],
...@@ -333,7 +345,7 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -333,7 +345,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
torch.float8_e5m2fnuz, torch.float8_e5m2fnuz,
] ]
if internal: if internal:
return Float8TensorBase( return Float8TensorStorage(
data=data, data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
fp8_dtype=self.dtype, fp8_dtype=self.dtype,
...@@ -388,7 +400,7 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -388,7 +400,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
return True return True
class Float8Tensor(Float8TensorBase, QuantizedTensor): class Float8Tensor(Float8TensorStorage, QuantizedTensor):
"""Experimental tensor class with FP8 data """Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype, The tensor presents as having a standard, higher-precision dtype,
...@@ -443,19 +455,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -443,19 +455,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
return _FromFloat8Func.apply(self, dtype) return _FromFloat8Func.apply(self, dtype)
return _FromFloat8Func.forward(None, self, dtype) return _FromFloat8Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
# Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling)
raise ValueError(
"Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable"
)
def quantize_( def quantize_(
self, self,
tensor: torch.Tensor, tensor: torch.Tensor,
...@@ -474,8 +473,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -474,8 +473,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
""" """
if isinstance(tensor, QuantizedTensor): if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) return super().quantize_(tensor, noop_flag=noop_flag)
return self
def detach(self) -> Float8Tensor: def detach(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
......
...@@ -16,8 +16,12 @@ from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe ...@@ -16,8 +16,12 @@ from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple from ..utils import devices_match, round_up_to_nearest_multiple
from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
aten = torch.ops.aten aten = torch.ops.aten
...@@ -67,6 +71,10 @@ class MXFP8Quantizer(Quantizer): ...@@ -67,6 +71,10 @@ class MXFP8Quantizer(Quantizer):
return dst return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def is_quantizable(self, inp: torch.Tensor) -> bool: def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized""" """Returns whether or not given inp can be quantized"""
if inp.ndim < 2: if inp.ndim < 2:
...@@ -161,14 +169,14 @@ class MXFP8Quantizer(Quantizer): ...@@ -161,14 +169,14 @@ class MXFP8Quantizer(Quantizer):
data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor) data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor: def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return MXFP8BlockScaling return MXFP8BlockScaling
class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
"""Experimental tensor class with FP8 data """Experimental tensor class with FP8 data
The tensor presents as having a standard, higher-precision dtype, The tensor presents as having a standard, higher-precision dtype,
...@@ -192,7 +200,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -192,7 +200,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
""" """
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args, # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++. # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__( def __new__(
cls, cls,
...@@ -236,17 +244,9 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -236,17 +244,9 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
return _FromMXFP8Func.apply(self, dtype) return _FromMXFP8Func.apply(self, dtype)
return _FromMXFP8Func.forward(None, self, dtype) return _FromMXFP8Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer: def _build_default_quantizer(self) -> Optional[Quantizer]:
"""Get builder for quantized tensor """Build default quantizer for the tensor"""
return MXFP8Quantizer(fp8_dtype=self._fp8_dtype)
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return MXFP8Quantizer(
fp8_dtype=self._fp8_dtype,
)
def quantize_( def quantize_(
self, self,
...@@ -266,8 +266,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -266,8 +266,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
""" """
if isinstance(tensor, QuantizedTensor): if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize()) return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) return super().quantize_(tensor, noop_flag=noop_flag)
return self
def detach(self) -> MXFP8Tensor: def detach(self) -> MXFP8Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
......
...@@ -21,7 +21,7 @@ from ..utils import ( ...@@ -21,7 +21,7 @@ from ..utils import (
round_up_to_nearest_multiple, round_up_to_nearest_multiple,
) )
from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
aten = torch.ops.aten aten = torch.ops.aten
...@@ -173,6 +173,10 @@ class NVFP4Quantizer(Quantizer): ...@@ -173,6 +173,10 @@ class NVFP4Quantizer(Quantizer):
return dst return dst
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
return tex.quantize(tensor, self)
def is_quantizable(self, inp: torch.Tensor) -> bool: def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized""" """Returns whether or not given inp can be quantized"""
if inp.ndim < 2: if inp.ndim < 2:
...@@ -332,7 +336,7 @@ class NVFP4Quantizer(Quantizer): ...@@ -332,7 +336,7 @@ class NVFP4Quantizer(Quantizer):
return NVFP4BlockScaling return NVFP4BlockScaling
class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
"""Quantized tensor class with FP4 data """Quantized tensor class with FP4 data
The tensor presents as having a standard, higher-precision dtype, The tensor presents as having a standard, higher-precision dtype,
...@@ -365,7 +369,7 @@ class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): ...@@ -365,7 +369,7 @@ class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor):
Nominal tensor datatype, used in dequantize. Nominal tensor datatype, used in dequantize.
""" """
# NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args, # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++. # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__( def __new__(
cls, cls,
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Tensor with quantized data""" """Tensor with quantized data"""
from __future__ import annotations from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union
import abc import abc
import copy import copy
import warnings import warnings
...@@ -13,12 +13,11 @@ import warnings ...@@ -13,12 +13,11 @@ import warnings
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
class QuantizedTensorBase: class QuantizedTensorStorage:
r"""Base class for all *TensorBase classes. r"""Base class for all *TensorStorage classes.
This class (and its subclasses) are optimization for when This class (and its subclasses) are optimization for when
the full QuantizedTensor is not needed (when it is fully the full QuantizedTensor is not needed (when it is fully
...@@ -26,9 +25,9 @@ class QuantizedTensorBase: ...@@ -26,9 +25,9 @@ class QuantizedTensorBase:
PyTorch's autograd). PyTorch's autograd).
When creating a new tensor type X one should create both When creating a new tensor type X one should create both
XTensorBase class inheriting from QuantizedTensorBase and XTensorStorage class inheriting from QuantizedTensorStorage and
XTensor inheriting from XTensorBase and QuantizedTensor. XTensor inheriting from XTensorStorage and QuantizedTensor.
XTensorBase should contain all data members needed to XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while implement the functionality of the tensor, while
XTensor should only implement the functionality needed XTensor should only implement the functionality needed
to behave like regular torch.Tensor (liek __torch_dispatch__).""" to behave like regular torch.Tensor (liek __torch_dispatch__)."""
...@@ -59,7 +58,7 @@ class QuantizedTensorBase: ...@@ -59,7 +58,7 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement update_usage function" f"{self.__class__.__name__} class does not implement update_usage function"
) )
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward""" """Prepare the tensor base for saving for backward"""
raise NotImplementedError( raise NotImplementedError(
f"{self.__class__.__name__} class does not implement prepare_for_saving function" f"{self.__class__.__name__} class does not implement prepare_for_saving function"
...@@ -73,6 +72,30 @@ class QuantizedTensorBase: ...@@ -73,6 +72,30 @@ class QuantizedTensorBase:
f"{self.__class__.__name__} class does not implement restore_from_saved function" f"{self.__class__.__name__} class does not implement restore_from_saved function"
) )
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return self._build_default_quantizer()
def _build_default_quantizer(self) -> Quantizer:
"""Build default quantizer for the tensor"""
raise ValueError(
f"{self.__class__.__name__} has no quantizer "
"and no default quantizer is available defined in the subclass."
)
def quantize_(
self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None
) -> QuantizedTensor:
"""Quantize tensor in-place"""
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def update_quantizer(self, quantizer: Quantizer): def update_quantizer(self, quantizer: Quantizer):
"""Update quantizer for the tensor""" """Update quantizer for the tensor"""
if self._quantizer is None: if self._quantizer is None:
...@@ -83,13 +106,13 @@ class QuantizedTensorBase: ...@@ -83,13 +106,13 @@ class QuantizedTensorBase:
def prepare_for_saving( def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorBase], *tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[ ) -> Tuple[
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]] list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]]
]: ]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only """Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal TensorBase types too.""" the internal *TensorStorage types too."""
tensor_list, tensor_objects_list = [], [] tensor_list, tensor_objects_list = [], []
for tensor in tensors: for tensor in tensors:
...@@ -104,12 +127,12 @@ def prepare_for_saving( ...@@ -104,12 +127,12 @@ def prepare_for_saving(
def restore_from_saved( def restore_from_saved(
tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]], tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
return_saved_tensors: bool = False, return_saved_tensors: bool = False,
) -> ( ) -> (
list[Optional[torch.Tensor | QuantizedTensorBase]] list[Optional[torch.Tensor | QuantizedTensorStorage]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]] | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
): ):
"""Recombine the tensor data and metadata during backward pass.""" """Recombine the tensor data and metadata during backward pass."""
tensor_objects = [] tensor_objects = []
...@@ -178,7 +201,6 @@ class Quantizer(abc.ABC): ...@@ -178,7 +201,6 @@ class Quantizer(abc.ABC):
")" ")"
) )
@abc.abstractmethod
def update_quantized( def update_quantized(
self, self,
src: torch.Tensor, src: torch.Tensor,
...@@ -187,6 +209,9 @@ class Quantizer(abc.ABC): ...@@ -187,6 +209,9 @@ class Quantizer(abc.ABC):
noop_flag: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Quantize tensor in-place""" """Quantize tensor in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_quantized"
)
def quantize( def quantize(
self, self,
...@@ -199,8 +224,14 @@ class Quantizer(abc.ABC): ...@@ -199,8 +224,14 @@ class Quantizer(abc.ABC):
if out is not None: if out is not None:
return self.update_quantized(tensor, out) return self.update_quantized(tensor, out)
if (not self.internal) and torch.is_grad_enabled(): if (not self.internal) and torch.is_grad_enabled():
return _QuantizeFunc.apply(tensor, self) return _QuantizeFunc.apply(tensor, self.quantize_impl)
return _QuantizeFunc.forward(None, tensor, self) return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Quantize tensor implementation"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement quantize_impl function"
)
def multi_quantize(self, list_of_tensors): def multi_quantize(self, list_of_tensors):
"""Quantize multiple tensors""" """Quantize multiple tensors"""
...@@ -213,7 +244,6 @@ class Quantizer(abc.ABC): ...@@ -213,7 +244,6 @@ class Quantizer(abc.ABC):
"""Quantize tensor""" """Quantize tensor"""
return self.quantize(tensor) return self.quantize(tensor)
@abc.abstractmethod
def make_empty( def make_empty(
self, self,
shape: Iterable[int], shape: Iterable[int],
...@@ -222,8 +252,11 @@ class Quantizer(abc.ABC): ...@@ -222,8 +252,11 @@ class Quantizer(abc.ABC):
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Construct quantized tensor with uninitialized data""" """Construct quantized tensor with uninitialized data"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function, "
"required for construction of unintialized quantized tensor"
)
@abc.abstractmethod
def calibrate(self, tensor: torch.Tensor) -> None: def calibrate(self, tensor: torch.Tensor) -> None:
"""Calibrate quantizer state """Calibrate quantizer state
...@@ -252,13 +285,21 @@ class Quantizer(abc.ABC): ...@@ -252,13 +285,21 @@ class Quantizer(abc.ABC):
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Symbolic function for ONNX export""" """Symbolic function for ONNX export"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement onnx_quantize"
)
def onnx_dequantize(self, tensor) -> torch.Tensor: def onnx_dequantize(self, tensor) -> torch.Tensor:
"""Symbolic function for ONNX export""" """Symbolic function for ONNX export"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement onnx_dequantize"
)
@abc.abstractmethod
def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer""" """Returns recipe class that is compatible with this quantizer"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_compatible_recipe"
)
def supports_only_rowwise_all_gather(self) -> bool: def supports_only_rowwise_all_gather(self) -> bool:
"""Returns True if the quantizer supports only rowwise all-gather""" """Returns True if the quantizer supports only rowwise all-gather"""
...@@ -270,20 +311,21 @@ class Quantizer(abc.ABC): ...@@ -270,20 +311,21 @@ class Quantizer(abc.ABC):
class _QuantizeFunc(torch.autograd.Function): class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype""" """Quantize tensor"""
@staticmethod @staticmethod
def forward( def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused _ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: torch.Tensor, tensor: torch.Tensor,
quantizer: Quantizer, quantize_impl: Callable,
) -> QuantizedTensor: ) -> QuantizedTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
return tex.quantize(tensor, quantizer) return quantize_impl(tensor)
@staticmethod @staticmethod
def backward( def backward(
_ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused _ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision # Assume that we want gradients in full precision
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Storage for quantized tensors."""
from .float8_tensor_storage import Float8TensorStorage # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401
...@@ -13,7 +13,7 @@ import transformer_engine_torch as tex ...@@ -13,7 +13,7 @@ import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat from transformer_engine_torch import Float8BlockScaleTensorFormat
from ..quantized_tensor import QuantizedTensorBase from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType_To_Torch from ...constants import TE_DType_To_Torch
...@@ -22,7 +22,7 @@ from ..quantized_tensor import Quantizer ...@@ -22,7 +22,7 @@ from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor from ...utils import _empty_tensor
class Float8BlockwiseQTensorBase(QuantizedTensorBase): class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of Float8BlockwiseQTensor. """Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
...@@ -53,7 +53,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -53,7 +53,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
*args, *args,
**kwargs, **kwargs,
): ):
if cls is Float8BlockwiseQTensorBase: if cls is Float8BlockwiseQTensorStorage:
instance = object.__new__(cls) instance = object.__new__(cls)
else: else:
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
...@@ -98,7 +98,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -98,7 +98,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def prepare_for_saving( def prepare_for_saving(
self, self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:
""" """
Prepare the tensor base for saving for backward Prepare the tensor base for saving for backward
""" """
...@@ -366,7 +366,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -366,7 +366,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
data = self.dequantize() data = self.dequantize()
descriptor = "columnwise" descriptor = "columnwise"
return ( return (
"Float8BlockwiseQTensorBase(" "Float8BlockwiseQTensorStorage("
f"fp8_dtype={self._fp8_dtype}, " f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}" f"{descriptor}_scaled_data={data}"
) )
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...constants import TE_DType as torch_to_transformer_engine_dtype
...@@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused _ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: Float8TensorBase, tensor: Float8TensorStorage,
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -52,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -52,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None return grad, None
class Float8TensorBase(QuantizedTensorBase): class Float8TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of Float8Tensor. """Mixin class that holds data attributes of Float8Tensor.
Float8Tensor inherits from the PyTorch tensor class and this mixin Float8Tensor inherits from the PyTorch tensor class and this mixin
...@@ -81,7 +81,7 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -81,7 +81,7 @@ class Float8TensorBase(QuantizedTensorBase):
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
**kwargs, **kwargs,
): ):
if cls is Float8TensorBase: if cls is Float8TensorStorage:
instance = object.__new__(cls) instance = object.__new__(cls)
else: else:
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
...@@ -116,7 +116,7 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -116,7 +116,7 @@ class Float8TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer, "quantizer": self._quantizer,
} }
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward""" """Prepare the tensor base for saving for backward"""
tensors = [self._data, self._transpose, self._scale_inv] tensors = [self._data, self._transpose, self._scale_inv]
self._data = None self._data = None
...@@ -163,7 +163,7 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -163,7 +163,7 @@ class Float8TensorBase(QuantizedTensorBase):
if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]: if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
out_transpose = None out_transpose = None
return Float8TensorBase( return Float8TensorStorage(
data=out_data, data=out_data,
fp8_scale_inv=self._scale_inv, fp8_scale_inv=self._scale_inv,
fp8_dtype=self._fp8_dtype, fp8_dtype=self._fp8_dtype,
...@@ -173,7 +173,7 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -173,7 +173,7 @@ class Float8TensorBase(QuantizedTensorBase):
def __repr__(self): def __repr__(self):
return ( return (
"Float8TensorBase(" "Float8TensorStorage("
f"fp8_dtype={self._fp8_dtype}, " f"fp8_dtype={self._fp8_dtype}, "
f"scale_inv={self._scale_inv.item()}, " f"scale_inv={self._scale_inv.item()}, "
f"data={self.dequantize()}" f"data={self.dequantize()}"
......
...@@ -13,7 +13,7 @@ import torch ...@@ -13,7 +13,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase from ..quantized_tensor import QuantizedTensorStorage
from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...constants import TE_DType as torch_to_transformer_engine_dtype
...@@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function): ...@@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused _ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: MXFP8TensorBase, tensor: MXFP8TensorStorage,
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -49,7 +49,7 @@ class _FromMXFP8Func(torch.autograd.Function): ...@@ -49,7 +49,7 @@ class _FromMXFP8Func(torch.autograd.Function):
return grad, None return grad, None
class MXFP8TensorBase(QuantizedTensorBase): class MXFP8TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of MXFP8Tensor. """Mixin class that holds data attributes of MXFP8Tensor.
MXFP8Tensor inherits from the PyTorch tensor class and this mixin MXFP8Tensor inherits from the PyTorch tensor class and this mixin
...@@ -77,7 +77,7 @@ class MXFP8TensorBase(QuantizedTensorBase): ...@@ -77,7 +77,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
*args, *args,
**kwargs, **kwargs,
): ):
if cls is MXFP8TensorBase: if cls is MXFP8TensorStorage:
instance = object.__new__(cls) instance = object.__new__(cls)
else: else:
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
...@@ -112,7 +112,7 @@ class MXFP8TensorBase(QuantizedTensorBase): ...@@ -112,7 +112,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer, "quantizer": self._quantizer,
} }
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
"""Prepare the tensor base for saving for backward""" """Prepare the tensor base for saving for backward"""
tensors = [ tensors = [
self._rowwise_data, self._rowwise_data,
...@@ -192,7 +192,7 @@ class MXFP8TensorBase(QuantizedTensorBase): ...@@ -192,7 +192,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
if cur_columnwise_data is not None: if cur_columnwise_data is not None:
new_columnwise_data = cur_columnwise_data.view(*shape) new_columnwise_data = cur_columnwise_data.view(*shape)
return MXFP8TensorBase( return MXFP8TensorStorage(
rowwise_data=new_rowwise_data, rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv, rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data, columnwise_data=new_columnwise_data,
...@@ -205,7 +205,7 @@ class MXFP8TensorBase(QuantizedTensorBase): ...@@ -205,7 +205,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
data_rowwise = self.dequantize() data_rowwise = self.dequantize()
return ( return (
"MXFP8TensorBase(" "MXFP8TensorStorage("
f"fp8_dtype={self._fp8_dtype}, " f"fp8_dtype={self._fp8_dtype}, "
f"rowwise_scaled_data={data_rowwise}" f"rowwise_scaled_data={data_rowwise}"
f"rowwise_scale_inv={self._rowwise_scale_inv}, " f"rowwise_scale_inv={self._rowwise_scale_inv}, "
......
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
# import transformer_engine_torch as tex # import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase from ..quantized_tensor import QuantizedTensorStorage
# from ...constants import TE_DType as torch_to_transformer_engine_dtype # from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer from ..quantized_tensor import Quantizer
...@@ -39,7 +39,7 @@ class _FromNVFP4Func(torch.autograd.Function): ...@@ -39,7 +39,7 @@ class _FromNVFP4Func(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused _ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: NVFP4TensorBase, tensor: NVFP4TensorStorage,
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -89,7 +89,7 @@ class _FromNVFP4Func(torch.autograd.Function): ...@@ -89,7 +89,7 @@ class _FromNVFP4Func(torch.autograd.Function):
return grad, None return grad, None
class NVFP4TensorBase(QuantizedTensorBase): class NVFP4TensorStorage(QuantizedTensorStorage):
"""Mixin class that holds data attributes of NVFP4Tensor. """Mixin class that holds data attributes of NVFP4Tensor.
NVFP4Tensor inherits from the PyTorch tensor class and this mixin NVFP4Tensor inherits from the PyTorch tensor class and this mixin
...@@ -161,7 +161,7 @@ class NVFP4TensorBase(QuantizedTensorBase): ...@@ -161,7 +161,7 @@ class NVFP4TensorBase(QuantizedTensorBase):
"quantizer": self._quantizer, "quantizer": self._quantizer,
} }
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
"""Prepare the tensor base for saving for backward""" """Prepare the tensor base for saving for backward"""
tensors = [ tensors = [
self._rowwise_data, self._rowwise_data,
...@@ -267,7 +267,7 @@ class NVFP4TensorBase(QuantizedTensorBase): ...@@ -267,7 +267,7 @@ class NVFP4TensorBase(QuantizedTensorBase):
new_columnwise_data = self._columnwise_data.view(byte_shape) new_columnwise_data = self._columnwise_data.view(byte_shape)
# Construct tensor # Construct tensor
return NVFP4TensorBase( return NVFP4TensorStorage(
rowwise_data=new_rowwise_data, rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv, rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data, columnwise_data=new_columnwise_data,
...@@ -282,7 +282,7 @@ class NVFP4TensorBase(QuantizedTensorBase): ...@@ -282,7 +282,7 @@ class NVFP4TensorBase(QuantizedTensorBase):
data_rowwise = self.dequantize() data_rowwise = self.dequantize()
return ( return (
"NVFP4TensorBase(" "NVFP4TensorStorage("
f"rowwise_scaled_data={data_rowwise}," f"rowwise_scaled_data={data_rowwise},"
f"rowwise_scale_inv={self._rowwise_scale_inv}," f"rowwise_scale_inv={self._rowwise_scale_inv},"
f"amax_rowwise={self._amax_rowwise}," f"amax_rowwise={self._amax_rowwise},"
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
...@@ -454,7 +454,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ...@@ -454,7 +454,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
) )
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool: def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
"""Check if an environment or object is using experimental Kitchen middleware. """Check if an environment or object is using experimental Kitchen middleware.
Returns False if x is a torch.Tensor. Returns False if x is a torch.Tensor.
...@@ -466,6 +466,6 @@ def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) - ...@@ -466,6 +466,6 @@ def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -
# Detect if the object is experimental # Detect if the object is experimental
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return False return False
if not isinstance(x, (Quantizer, QuantizedTensorBase)): if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance") raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
return hasattr(x, "experimental") and x.experimental return hasattr(x, "experimental") and x.experimental
...@@ -225,13 +225,15 @@ class SplitAlongDim(torch.autograd.Function): ...@@ -225,13 +225,15 @@ class SplitAlongDim(torch.autograd.Function):
ctx.split_dim = split_dim ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import (
Float8TensorStorage,
)
if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance(
mixed_x_layer, Float8Tensor mixed_x_layer, Float8Tensor
): ):
return tuple( return tuple(
Float8TensorBase( Float8TensorStorage(
fp8_scale_inv=mixed_x_layer._scale_inv, fp8_scale_inv=mixed_x_layer._scale_inv,
fp8_dtype=mixed_x_layer._fp8_dtype, fp8_dtype=mixed_x_layer._fp8_dtype,
data=x.squeeze(split_dim) if squeeze else x, data=x.squeeze(split_dim) if squeeze else x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment