Unverified commit ce2e8bd1 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Decouple python quantization classes and refactor custom quantization (#2276)



* rename experimental -> custom_recipes
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Decouple python base classes (api)
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* update test_custom_recipe
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename experimental -> custom
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix import
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Update tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* Update tests/pytorch/test_custom_recipe.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* quantization_base -> quantized_tensor rename
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2712bb95
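At a glance, the refactor moves the pure-Python quantization base classes out of the `tensor` subpackage into `transformer_engine/pytorch/quantized_tensor.py` and renames the "experimental" recipe check to "custom". A minimal before/after sketch of what downstream imports look like (module paths taken from the diff below; the snippet itself is illustrative and not part of this commit):

```python
# Before this PR: base classes lived in the tensor subpackage, and
# custom recipes were still called "experimental".
# from transformer_engine.pytorch.tensor.quantized_tensor import (
#     QuantizedTensor, QuantizedTensorStorage, Quantizer,
# )
# from transformer_engine.pytorch.tensor.utils import is_experimental

# After this PR: the decoupled base classes are imported one level up,
# and the recipe check is renamed.
from transformer_engine.pytorch.quantized_tensor import (
    QuantizedTensor,
    QuantizedTensorStorage,
    Quantizer,
)
from transformer_engine.pytorch.tensor.utils import is_custom
```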
@@ -41,7 +41,7 @@ from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentSc
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
+from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
@@ -38,7 +38,7 @@ from ..distributed import (
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
-from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
+from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
@@ -43,7 +43,7 @@ from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
@@ -16,7 +16,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.tensor.utils import is_experimental
+from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
@@ -56,7 +56,7 @@ from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Quantizer,
@@ -194,13 +194,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
-experimental = is_experimental(input_quantizer)
+custom = is_custom(input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
-and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
+and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)
# Apply normalization
@@ -246,8 +246,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
-# experimental recipe doesn't need to support quantized AG
-if not with_quantized_norm and not experimental:
+# custom recipe doesn't need to support quantized AG
+if not with_quantized_norm and not custom:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
@@ -17,7 +17,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.tensor.utils import is_experimental
+from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
@@ -70,7 +70,7 @@ from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
@@ -268,13 +268,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
-experimental = is_experimental(fc1_input_quantizer)
+custom = is_custom(fc1_input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
-and not experimental
+and not custom
)
# Apply normalization
@@ -314,8 +314,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
-# experimental recipe doesn't need to support quantized AG
-if not with_quantized_norm and not experimental:
+# custom recipe doesn't need to support quantized AG
+if not with_quantized_norm and not custom:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
@@ -57,7 +57,7 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Quantizer,
@@ -66,7 +66,7 @@ from ..tensor.quantized_tensor import (
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ..tensor.utils import is_experimental
+from ..tensor.utils import is_custom
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
@@ -153,8 +153,8 @@ class _Linear(torch.autograd.Function):
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
-# experimental recipe check
-experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
+# custom recipe check
+custom = is_custom(input_quantizer) or is_custom(weight_quantizer)
# ------------------------------------------------------
# Prepare input tensor
@@ -178,7 +178,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
-if not isinstance(inputmat, QuantizedTensorStorage) and not experimental:
+if not isinstance(inputmat, QuantizedTensorStorage) and not custom:
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
@@ -448,7 +448,7 @@ class _Linear(torch.autograd.Function):
ctx.main_grad_func = lambda: weight.main_grad
ctx.debug = debug
-ctx.experimental = experimental
+ctx.custom = custom
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None
@@ -616,7 +616,7 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat, QuantizedTensorStorage):
# Input tensor is already quantized
pass
-elif ctx.debug or ctx.experimental:
+elif ctx.debug or ctx.custom:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
@@ -13,7 +13,7 @@ from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..quantization import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.quantized_tensor import QuantizedTensorStorage
+from ..quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype
@@ -21,7 +21,7 @@ from ...module.base import (
get_ub,
get_workspace,
)
-from ...tensor.quantized_tensor import Quantizer
+from ...quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
@@ -21,7 +21,7 @@ from ...module.base import (
get_workspace,
_2X_ACC_FPROP,
)
-from ...tensor.quantized_tensor import Quantizer
+from ...quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_dequantize, is_quantized_tensor
@@ -28,7 +28,7 @@ from transformer_engine.pytorch.ops.fused import (
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
-from transformer_engine.pytorch.tensor.quantized_tensor import (
+from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
@@ -2,10 +2,10 @@
#
# See LICENSE for license information.
"""Tensor with quantized data"""
"""Pure Python base classes for quantization."""
from __future__ import annotations
-from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union
+from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
@@ -14,6 +14,11 @@ import torch
from torch.utils._pytree import tree_map
from transformer_engine.common.recipe import Recipe
+from transformer_engine.pytorch.tensor._quantization_helpers import (
+    _QuantizeFunc,
+    _IdentityFunc,
+    _stride_from_shape,
+)
class QuantizedTensorStorage:
@@ -310,73 +315,6 @@ class Quantizer(abc.ABC):
return True
-class _QuantizeFunc(torch.autograd.Function):
-    """Quantize tensor"""
-
-    @staticmethod
-    def forward(
-        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
-        tensor: torch.Tensor,
-        quantize_impl: Callable,
-    ) -> QuantizedTensor:
-        # pylint: disable=missing-function-docstring
-        return quantize_impl(tensor)
-
-    @staticmethod
-    def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
-    ) -> Tuple[Optional[torch.Tensor], ...]:
-        # pylint: disable=missing-function-docstring
-        # Assume that we want gradients in full precision
-        return grad, None
-
-
-class _IdentityFunc(torch.autograd.Function):
-    """Identity function
-
-    If constructor keyword-arguments are provided, then construct a
-    new Float8Tensor using the provided tensor's attributes.
-    """
-
-    @staticmethod
-    def forward(
-        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
-    ) -> QuantizedTensor:
-        # pylint: disable=missing-function-docstring
-        # Return input tensor if constructor kwargs are not provided
-        if init_kwargs is None:
-            return tensor.detach()
-        # Construct new tensor if constructor kwargs are provided
-        ctx.input_dtype = tensor.dtype
-        kwargs = tensor.get_metadata()
-        for key, val in init_kwargs.items():
-            kwargs[key] = val
-        return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # pylint: disable=missing-function-docstring
-        grad_input = grad_output
-        if grad_input.dtype == ctx.input_dtype:
-            grad_input = grad_input.detach()
-        else:
-            grad_input = grad_input.to(ctx.input_dtype)
-        return grad_input, None
-
-
-def _stride_from_shape(shape: list[int]):
-    if len(shape) == 0:
-        return []
-    rstride = [1]
-    for d in reversed(shape[1:]):
-        rstride.append(rstride[-1] * d)
-    return list(reversed(rstride))
class QuantizedTensor(torch.Tensor):
"""Abstract base class for tensor with quantized data
@@ -6,7 +6,7 @@
import torch
-from .quantized_tensor import (
+from ..quantized_tensor import (
QuantizedTensorStorage,
QuantizedTensor,
Quantizer,
new file: transformer_engine/pytorch/tensor/_quantization_helpers.py
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Private helper functions and classes for quantized tensor implementations.
+
+This module contains internal autograd functions and utilities that support
+the quantization machinery.
+"""
+from __future__ import annotations
+from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
+
+
+class _QuantizeFunc(torch.autograd.Function):
+    """Quantize tensor"""
+
+    @staticmethod
+    def forward(
+        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
+        tensor: torch.Tensor,
+        quantize_impl: Callable,
+    ) -> QuantizedTensor:
+        # pylint: disable=missing-function-docstring
+        return quantize_impl(tensor)
+
+    @staticmethod
+    def backward(
+        _ctx: torch.autograd.function.FunctionCtx,  # unused
+        grad: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], ...]:
+        # pylint: disable=missing-function-docstring
+        # Assume that we want gradients in full precision
+        return grad, None
+
+
+class _IdentityFunc(torch.autograd.Function):
+    """Identity function
+
+    If constructor keyword-arguments are provided, then construct a
+    new Float8Tensor using the provided tensor's attributes.
+    """
+
+    @staticmethod
+    def forward(
+        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
+    ) -> QuantizedTensor:
+        # pylint: disable=missing-function-docstring
+        # Return input tensor if constructor kwargs are not provided
+        if init_kwargs is None:
+            return tensor.detach()
+        # Construct new tensor if constructor kwargs are provided
+        ctx.input_dtype = tensor.dtype
+        kwargs = tensor.get_metadata()
+        for key, val in init_kwargs.items():
+            kwargs[key] = val
+        return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # pylint: disable=missing-function-docstring
+        grad_input = grad_output
+        if grad_input.dtype == ctx.input_dtype:
+            grad_input = grad_input.detach()
+        else:
+            grad_input = grad_input.to(ctx.input_dtype)
+        return grad_input, None
+
+
+def _stride_from_shape(shape: list[int]):
+    """Calculate stride from shape for contiguous tensors"""
+    if len(shape) == 0:
+        return []
+    rstride = [1]
+    for d in reversed(shape[1:]):
+        rstride.append(rstride[-1] * d)
+    return list(reversed(rstride))
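For reference, here is a quick sanity check of the relocated helpers' contracts as they appear in the new module above (a hypothetical usage sketch, not part of this commit; the identity lambda is a stand-in for a real `quantize_impl`):

```python
import torch

from transformer_engine.pytorch.tensor._quantization_helpers import (
    _QuantizeFunc,
    _stride_from_shape,
)

# _stride_from_shape computes contiguous row-major strides.
assert _stride_from_shape([]) == []
assert _stride_from_shape([2, 3, 4]) == [12, 4, 1]

# _QuantizeFunc runs an arbitrary quantize_impl in forward and passes
# gradients through unchanged in backward.
x = torch.randn(4, requires_grad=True)
y = _QuantizeFunc.apply(x, lambda t: t * 1.0)  # identity stand-in for quantize_impl
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
```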
@@ -14,11 +14,8 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
-from .quantized_tensor import (
-    QuantizedTensor,
-    Quantizer,
-    _IdentityFunc,
-)
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten
@@ -14,11 +14,8 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
-from .quantized_tensor import (
-    QuantizedTensor,
-    Quantizer,
-    _IdentityFunc,
-)
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc
from ..constants import dist_group_type
aten = torch.ops.aten
@@ -17,11 +17,8 @@ from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
-from .quantized_tensor import (
-    QuantizedTensor,
-    Quantizer,
-    _IdentityFunc,
-)
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc
aten = torch.ops.aten
@@ -22,7 +22,8 @@ from ..utils import (
)
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
-from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
+from ..quantized_tensor import QuantizedTensor, Quantizer
+from ._quantization_helpers import _IdentityFunc
aten = torch.ops.aten
@@ -13,12 +13,10 @@ import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
-from ..quantized_tensor import QuantizedTensorStorage
+from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ...constants import TE_DType_To_Torch
-from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor