Unverified commit ce2e8bd1, authored by Evgeny Tsykunov and committed by GitHub
Browse files

[PyTorch] Decouple python quantization classes and refactor custom quantization (#2276)



* rename experimental -> custom_recipes
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Decouple python base classes (api)
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* update test_custom_recipe
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename experimental -> custom
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix import
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Update tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* Update tests/pytorch/test_custom_recipe.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* quantization_base -> quantized_tensor rename
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2712bb95
...@@ -12,12 +12,10 @@ import torch ...@@ -12,12 +12,10 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor
......
...@@ -13,12 +13,10 @@ import torch ...@@ -13,12 +13,10 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor from ...utils import _empty_tensor
......
...@@ -16,10 +16,9 @@ import torch ...@@ -16,10 +16,9 @@ import torch
# import transformer_engine_torch as tex # import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage from ...quantized_tensor import QuantizedTensorStorage, Quantizer
# from ...constants import TE_DType as torch_to_transformer_engine_dtype # from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor from ...utils import _empty_tensor
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
"""Helper functions for using fp8 tensors as weights""" """Helper functions for using fp8 tensors as weights"""
import os from typing import Optional, Union, List
from typing import Optional, List, Union
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
...@@ -459,18 +459,13 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten ...@@ -459,18 +459,13 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
raise ValueError(f"post_processing for {type(model_weight)} is not supported") raise ValueError(f"post_processing for {type(model_weight)} is not supported")
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
"""Check if an environment or object is using experimental Kitchen middleware. """Check if an object is custom.
Returns False if x is a torch.Tensor. Returns False if x is a torch.Tensor.
""" """
# Detect if the environment is experimental if x is None or isinstance(x, torch.Tensor):
if x is None:
return int(os.getenv("QAT_PARAMS", "0")) > 0
# Detect if the object is experimental
if isinstance(x, torch.Tensor):
return False return False
if not isinstance(x, (Quantizer, QuantizedTensorStorage)): if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance") raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
return hasattr(x, "experimental") and x.experimental return hasattr(x, "custom") and x.custom
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
import torch import torch
from . import torch_version from . import torch_version
from .tensor.quantized_tensor import Quantizer from .quantized_tensor import Quantizer
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment