Unverified commit ce2e8bd1, authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Decouple python quantization classes and refactor custom quantization (#2276)



* rename experimental -> custom_recipes
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Decouple python base classes (api)
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* update test_custom_recipe
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename experimental -> custom
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix import
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Update tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* Update tests/pytorch/test_custom_recipe.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* quantization_base -> quantized_tensor rename
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2712bb95
@@ -12,12 +12,10 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from ..quantized_tensor import QuantizedTensorStorage
+from ...quantized_tensor import QuantizedTensorStorage, Quantizer
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
-from ..quantized_tensor import Quantizer
 from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor

@@ -13,12 +13,10 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from ..quantized_tensor import QuantizedTensorStorage
+from ...quantized_tensor import QuantizedTensorStorage, Quantizer
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
-from ..quantized_tensor import Quantizer
 from ...utils import _empty_tensor

@@ -16,10 +16,9 @@ import torch
 # import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from ..quantized_tensor import QuantizedTensorStorage
+from ...quantized_tensor import QuantizedTensorStorage, Quantizer
 # from ...constants import TE_DType as torch_to_transformer_engine_dtype
-from ..quantized_tensor import Quantizer
 from ...utils import _empty_tensor

@@ -4,13 +4,13 @@
 """Helper functions for using fp8 tensors as weights"""
 import os
-from typing import Optional, List, Union
+from typing import Optional, Union, List
 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
-from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
+from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
 from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
 from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer

@@ -459,18 +459,13 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
         raise ValueError(f"post_processing for {type(model_weight)} is not supported")


-def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
-    """Check if an environment or object is using experimental Kitchen middleware.
+def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
+    """Check if an object is custom.

     Returns False if x is a torch.Tensor.
     """
-    # Detect if the environment is experimental
-    if x is None:
-        return int(os.getenv("QAT_PARAMS", "0")) > 0
-    # Detect if the object is experimental
-    if isinstance(x, torch.Tensor):
+    if x is None or isinstance(x, torch.Tensor):
         return False
     if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
         raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
-    return hasattr(x, "experimental") and x.experimental
+    return hasattr(x, "custom") and x.custom

@@ -12,7 +12,7 @@ import numpy as np
 import torch
 from . import torch_version
-from .tensor.quantized_tensor import Quantizer
+from .quantized_tensor import Quantizer
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor