Unverified Commit ce2e8bd1 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Decouple python quantization classes and refactor custom quantization (#2276)



* rename experimental -> custom_recipes
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Decouple python base classes (api)
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* update test_custom_recipe
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename experimental -> custom
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix import
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Update tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* Update tests/pytorch/test_custom_recipe.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* quantization_base -> quantized_tensor rename
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2712bb95
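The substance of the refactor: the pure-Python base classes (QuantizedTensorStorage, QuantizedTensor, Quantizer) move from transformer_engine.pytorch.tensor.quantized_tensor to a standalone transformer_engine.pytorch.quantized_tensor module, the private autograd helpers move into transformer_engine/pytorch/tensor/_quantization_helpers.py, and the "experimental" recipe naming becomes "custom" (is_experimental -> is_custom). Downstream code that imports the base classes directly can bridge the move with a fallback import; the shim below is an illustrative sketch, not part of the commit:

```python
# Import shim for code that must run before and after this refactor.
# Both module paths appear verbatim in the diff below; the try/except is ours.
try:
    # New location (this commit onward)
    from transformer_engine.pytorch.quantized_tensor import (
        QuantizedTensor,
        QuantizedTensorStorage,
        Quantizer,
    )
except ImportError:
    # Old location (before this commit)
    from transformer_engine.pytorch.tensor.quantized_tensor import (
        QuantizedTensor,
        QuantizedTensorStorage,
        Quantizer,
    )
```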
@@ -41,7 +41,7 @@ from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentSc
 from .tensor.mxfp8_tensor import MXFP8Quantizer
 from .tensor.nvfp4_tensor import NVFP4Quantizer
 from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
+from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
 from .tensor.storage.float8_tensor_storage import Float8TensorStorage
 from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
...
@@ -38,7 +38,7 @@ from ..distributed import (
     _fsdp_gather_tensors,
 )
 from ..constants import dist_group_type
-from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
+from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
 from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
...
@@ -43,7 +43,7 @@ from ..graph import is_graph_capturing
 from ..cpu_offload import is_cpu_offload_enabled
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
     QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
...
@@ -16,7 +16,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.tensor.utils import is_experimental
+from transformer_engine.pytorch.tensor.utils import is_custom
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_workspace,
@@ -56,7 +56,7 @@ from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ._common import apply_normalization, noop_cat, WeightGradStore
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
     QuantizedTensor,
     QuantizedTensorStorage,
     Quantizer,
@@ -194,13 +194,13 @@ class _LayerNormLinear(torch.autograd.Function):
         # Avoid quantized norm kernel if norm output will be returned
         # or if a gather of ln_out must be in high precision.
-        experimental = is_experimental(input_quantizer)
+        custom = is_custom(input_quantizer)
         with_quantized_norm = (
             fp8
             and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
-            and not experimental  # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
+            and not custom  # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
         )

         # Apply normalization
@@ -246,8 +246,8 @@ class _LayerNormLinear(torch.autograd.Function):
             quantizer = None
             if fp8 or debug:
                 quantizer = input_quantizer
-                # experimental recipe doesn't need to support quantized AG
-                if not with_quantized_norm and not experimental:
+                # custom recipe doesn't need to support quantized AG
+                if not with_quantized_norm and not custom:
                     ln_out = quantizer(ln_out)
                     quantizer.set_usage(rowwise=True, columnwise=False)
             if ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
...
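The gate above depends on is_custom, imported from transformer_engine.pytorch.tensor.utils; its body is not part of this diff. A hypothetical equivalent, assuming "custom" means any quantizer that is not one of TE's built-in quantizer classes (the name and logic below are our guess, not TE's actual implementation):

```python
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

# Built-in quantizer classes, as listed among the imports in this diff.
_BUILTIN_QUANTIZERS = (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    Float8BlockQuantizer,
    MXFP8Quantizer,
    NVFP4Quantizer,
)


def is_custom_quantizer(quantizer) -> bool:
    """Sketch: treat any non-None quantizer outside the built-ins as custom."""
    return quantizer is not None and not isinstance(quantizer, _BUILTIN_QUANTIZERS)
```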
@@ -17,7 +17,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.tensor.utils import is_experimental
+from transformer_engine.pytorch.tensor.utils import is_custom
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_workspace,
@@ -70,7 +70,7 @@ from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ._common import apply_normalization, WeightGradStore
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
     QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
@@ -268,13 +268,13 @@ class _LayerNormMLP(torch.autograd.Function):
         # high precision layernorm output and output of the linear are returned
         # for debug: layernorm output = high precision to enable processing of this norm
-        experimental = is_experimental(fc1_input_quantizer)
+        custom = is_custom(fc1_input_quantizer)
         with_quantized_norm = (
             fp8
             and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
-            and not experimental
+            and not custom
         )

         # Apply normalization
@@ -314,8 +314,8 @@ class _LayerNormMLP(torch.autograd.Function):
             quantizer = None
             if fp8 or debug:
                 quantizer = fc1_input_quantizer
-                # experimental recipe doesn't need to support quantized AG
-                if not with_quantized_norm and not experimental:
+                # custom recipe doesn't need to support quantized AG
+                if not with_quantized_norm and not custom:
                     ln_out = fc1_input_quantizer(ln_out)
                     fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
             if ub_overlap_ag:
...
@@ -57,7 +57,7 @@ from ..cpp_extensions import (
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
     QuantizedTensor,
     QuantizedTensorStorage,
     Quantizer,
@@ -66,7 +66,7 @@ from ..tensor.quantized_tensor import (
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ..tensor.utils import is_experimental
+from ..tensor.utils import is_custom
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
@@ -153,8 +153,8 @@ class _Linear(torch.autograd.Function):
             ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
-        # experimental recipe check
-        experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
+        # custom recipe check
+        custom = is_custom(input_quantizer) or is_custom(weight_quantizer)

         # ------------------------------------------------------
         # Prepare input tensor
@@ -178,7 +178,7 @@ class _Linear(torch.autograd.Function):
         if fp8 or debug:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            if not isinstance(inputmat, QuantizedTensorStorage) and not experimental:
+            if not isinstance(inputmat, QuantizedTensorStorage) and not custom:
                 own_quantized_input = True
                 input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                 if isinstance(
@@ -448,7 +448,7 @@ class _Linear(torch.autograd.Function):
         ctx.main_grad_func = lambda: weight.main_grad
         ctx.debug = debug
-        ctx.experimental = experimental
+        ctx.custom = custom
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
         ctx.use_bias = bias is not None
@@ -616,7 +616,7 @@ class _Linear(torch.autograd.Function):
         if isinstance(inputmat, QuantizedTensorStorage):
             # Input tensor is already quantized
             pass
-        elif ctx.debug or ctx.experimental:
+        elif ctx.debug or ctx.custom:
             # Debug quantizer will be applied immediately before wgrad GEMM
             pass
         else:
...
@@ -13,7 +13,7 @@ from transformer_engine_torch import FP8TensorMeta
 from .. import torch_version
 from ..quantization import FP8GlobalStateManager
 from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.quantized_tensor import QuantizedTensorStorage
+from ..quantized_tensor import QuantizedTensorStorage
 from ..utils import canonicalize_dtype
...
@@ -21,7 +21,7 @@ from ...module.base import (
     get_ub,
     get_workspace,
 )
-from ...tensor.quantized_tensor import Quantizer
+from ...quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
 from ..basic import BasicLinear, Bias, ReduceScatter
...
@@ -21,7 +21,7 @@ from ...module.base import (
     get_workspace,
     _2X_ACC_FPROP,
 )
-from ...tensor.quantized_tensor import Quantizer
+from ...quantized_tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
 from .._common import maybe_dequantize, is_quantized_tensor
...
@@ -28,7 +28,7 @@ from transformer_engine.pytorch.ops.fused import (
     fuse_userbuffers_backward_linear,
     fuse_userbuffers_forward_linear,
 )
-from transformer_engine.pytorch.tensor.quantized_tensor import (
+from transformer_engine.pytorch.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
...
@@ -10,7 +10,7 @@ import torch
 import transformer_engine_torch as tex
 import transformer_engine.pytorch.triton.permutation as triton_permutation
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
...
@@ -2,10 +2,10 @@
 #
 # See LICENSE for license information.

-"""Tensor with quantized data"""
+"""Pure Python base classes for quantization."""
 from __future__ import annotations

-from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union
+from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
 import copy
 import warnings
@@ -14,6 +14,11 @@ import torch
 from torch.utils._pytree import tree_map

 from transformer_engine.common.recipe import Recipe
+from transformer_engine.pytorch.tensor._quantization_helpers import (
+    _QuantizeFunc,
+    _IdentityFunc,
+    _stride_from_shape,
+)


 class QuantizedTensorStorage:
@@ -310,73 +315,6 @@ class Quantizer(abc.ABC):
         return True


-class _QuantizeFunc(torch.autograd.Function):
-    """Quantize tensor"""
-
-    @staticmethod
-    def forward(
-        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
-        tensor: torch.Tensor,
-        quantize_impl: Callable,
-    ) -> QuantizedTensor:
-        # pylint: disable=missing-function-docstring
-        return quantize_impl(tensor)
-
-    @staticmethod
-    def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
-    ) -> Tuple[Optional[torch.Tensor], ...]:
-        # pylint: disable=missing-function-docstring
-        # Assume that we want gradients in full precision
-        return grad, None
-
-
-class _IdentityFunc(torch.autograd.Function):
-    """Identity function
-
-    If constructor keyword-arguments are provided, then construct a
-    new Float8Tensor using the provided tensor's attributes.
-    """
-
-    @staticmethod
-    def forward(
-        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
-    ) -> QuantizedTensor:
-        # pylint: disable=missing-function-docstring
-        # Return input tensor if constructor kwargs are not provided
-        if init_kwargs is None:
-            return tensor.detach()
-        # Construct new tensor if constructor kwargs are provided
-        ctx.input_dtype = tensor.dtype
-        kwargs = tensor.get_metadata()
-        for key, val in init_kwargs.items():
-            kwargs[key] = val
-        return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # pylint: disable=missing-function-docstring
-        grad_input = grad_output
-        if grad_input.dtype == ctx.input_dtype:
-            grad_input = grad_input.detach()
-        else:
-            grad_input = grad_input.to(ctx.input_dtype)
-        return grad_input, None
-
-
-def _stride_from_shape(shape: list[int]):
-    if len(shape) == 0:
-        return []
-    rstride = [1]
-    for d in reversed(shape[1:]):
-        rstride.append(rstride[-1] * d)
-    return list(reversed(rstride))
-
-
 class QuantizedTensor(torch.Tensor):
     """Abstract base class for tensor with quantized data
...
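The removed _QuantizeFunc (re-homed in _quantization_helpers.py below) is a straight-through wrapper: forward applies the supplied quantize_impl, backward returns the incoming gradient unchanged in full precision. A minimal sketch of that behavior with a stand-in quantize_impl (the rounding function is ours; real callers pass a Quantizer's quantize routine that returns a QuantizedTensor):

```python
import torch

from transformer_engine.pytorch.tensor._quantization_helpers import _QuantizeFunc


def toy_quantize(t: torch.Tensor) -> torch.Tensor:
    # Stand-in quantize_impl: snap values to a coarse grid.
    return (t * 4).round() / 4


x = torch.randn(8, requires_grad=True)
y = _QuantizeFunc.apply(x, toy_quantize)
y.sum().backward()
# Backward passes the gradient straight through in full precision.
assert torch.equal(x.grad, torch.ones_like(x))
```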
@@ -6,7 +6,7 @@
 import torch

-from .quantized_tensor import (
+from ..quantized_tensor import (
     QuantizedTensorStorage,
     QuantizedTensor,
     Quantizer,
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Private helper functions and classes for quantized tensor implementations.

This module contains internal autograd functions and utilities that support
the quantization machinery.
"""

from __future__ import annotations
from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from transformer_engine.pytorch.quantized_tensor import QuantizedTensor


class _QuantizeFunc(torch.autograd.Function):
    """Quantize tensor"""

    @staticmethod
    def forward(
        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
        tensor: torch.Tensor,
        quantize_impl: Callable,
    ) -> QuantizedTensor:
        # pylint: disable=missing-function-docstring
        return quantize_impl(tensor)

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        # Assume that we want gradients in full precision
        return grad, None


class _IdentityFunc(torch.autograd.Function):
    """Identity function

    If constructor keyword-arguments are provided, then construct a
    new Float8Tensor using the provided tensor's attributes.
    """

    @staticmethod
    def forward(
        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
    ) -> QuantizedTensor:
        # pylint: disable=missing-function-docstring
        # Return input tensor if constructor kwargs are not provided
        if init_kwargs is None:
            return tensor.detach()
        # Construct new tensor if constructor kwargs are provided
        ctx.input_dtype = tensor.dtype
        kwargs = tensor.get_metadata()
        for key, val in init_kwargs.items():
            kwargs[key] = val
        return type(tensor)(tensor.shape, tensor.dtype, **kwargs)

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        grad_input = grad_output
        if grad_input.dtype == ctx.input_dtype:
            grad_input = grad_input.detach()
        else:
            grad_input = grad_input.to(ctx.input_dtype)
        return grad_input, None


def _stride_from_shape(shape: list[int]):
    """Calculate stride from shape for contiguous tensors"""
    if len(shape) == 0:
        return []
    rstride = [1]
    for d in reversed(shape[1:]):
        rstride.append(rstride[-1] * d)
    return list(reversed(rstride))
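_stride_from_shape reproduces PyTorch's contiguous-stride rule: the innermost dimension has stride 1, and each outer stride is the product of all inner dimension sizes. A quick sanity check against torch.Tensor.stride():

```python
import torch

from transformer_engine.pytorch.tensor._quantization_helpers import _stride_from_shape

x = torch.empty(2, 3, 4)
assert _stride_from_shape(list(x.shape)) == [12, 4, 1]
assert _stride_from_shape(list(x.shape)) == list(x.stride())
assert _stride_from_shape([]) == []  # zero-dim (scalar) shape -> empty stride
```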
@@ -14,11 +14,8 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.common.recipe import Float8BlockScaling, Recipe
 from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
-from .quantized_tensor import (
-    QuantizedTensor,
-    Quantizer,
-    _IdentityFunc,
-)
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc
 from ..utils import devices_match, round_up_to_nearest_multiple

 aten = torch.ops.aten
...
@@ -14,11 +14,8 @@ from transformer_engine_torch import DType as TE_DType
 from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
 from ..utils import canonicalize_process_group, devices_match
 from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
-from .quantized_tensor import (
-    QuantizedTensor,
-    Quantizer,
-    _IdentityFunc,
-)
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc
 from ..constants import dist_group_type

 aten = torch.ops.aten
...
@@ -17,11 +17,8 @@ from ..constants import MXFP8_BLOCK_SCALING_SIZE
 from ..utils import devices_match, round_up_to_nearest_multiple
 from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
-from .quantized_tensor import (
-    QuantizedTensor,
-    Quantizer,
-    _IdentityFunc,
-)
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc

 aten = torch.ops.aten
...
@@ -22,7 +22,8 @@ from ..utils import (
 )
 from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
-from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc

 aten = torch.ops.aten
...
@@ -13,12 +13,10 @@ import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
 from transformer_engine_torch import Float8BlockScaleTensorFormat

-from ..quantized_tensor import QuantizedTensorStorage
+from ...quantized_tensor import QuantizedTensorStorage, Quantizer
 from ...constants import TE_DType_To_Torch
-from ..quantized_tensor import Quantizer
 from ...utils import _empty_tensor
...