Unverified commit ce2e8bd1 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Decouple python quantization classes and refactor custom quantization (#2276)



* rename experimental -> custom_recipes
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Decouple python base classes (api)
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* update test_custom_recipe
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename experimental -> custom
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix import
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Update tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* Update tests/pytorch/test_custom_recipe.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>

* quantization_base -> quantized_tensor rename
Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <e.tsykunov@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2712bb95
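
Most of the diff below is a mechanical rename. As orientation, the two moves this commit performs look roughly like this (a before/after sketch assembled from the hunks that follow, not an exhaustive import list):

# Before this commit:
#   from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer
#   from transformer_engine.pytorch.experimental import quantization_nvfp4, utils
# After: the base quantization classes are decoupled into their own module,
# and the "experimental" package becomes "custom_recipes":
from transformer_engine.pytorch.quantized_tensor import (
    Quantizer,
    QuantizedTensor,
    QuantizedTensorStorage,
    prepare_for_saving,
    restore_from_saved,
)
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4, utils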
......@@ -45,7 +45,8 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
from transformer_engine.pytorch.quantized_tensor import (
Quantizer,
prepare_for_saving,
restore_from_saved,
)
......
......@@ -22,8 +22,8 @@ from transformer_engine.common.recipe import (
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors
......@@ -486,7 +486,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer.
QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference
QUANTIZATION options: nvfp4 <=> custom nvfp4 as a reference
"""
params_dtype = torch.bfloat16
use_bias = kwargs.get("bias", True)
......
......@@ -34,6 +34,7 @@ from transformer_engine.pytorch import (
Float8Tensor,
)
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
......
......@@ -14,7 +14,7 @@ import transformer_engine.pytorch as te
Distributed numerics tests
This numerical test aims for zero tolerance, for absolute confidence in numerics.
In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise
In the case of NVFP4, with the custom NVFP4 quantization, we matched bitwise
result with the native silicon. For distributed test cases, we can do the same thing
by comparing BF16 AG results with the low-precision AG results at layer level.
"""
......
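
A zero-tolerance, layer-level comparison of the kind this docstring describes could be written as below (a sketch only; ag_bf16 and ag_quantized are hypothetical placeholders for the two all-gathered outputs):

import torch

# Zero tolerance: the custom NVFP4 path matches the native result bitwise,
# so no rtol/atol slack is allowed.
torch.testing.assert_close(ag_quantized, ag_bf16, rtol=0.0, atol=0.0)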
......@@ -8,8 +8,8 @@ import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......
......@@ -6,8 +6,8 @@ import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......
......@@ -7,10 +7,10 @@ import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......
......@@ -12,10 +12,10 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.common.recipe import NVFP4BlockScaling
import pytest
import torch
......
......@@ -17,6 +17,48 @@ from transformer_engine.pytorch import (
Float8CurrentScalingQuantizer,
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
nvfp4_ref_rht_2d_quantizer_factory,
)
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear"])
def test_custom_recipe_sanity_modules_nvfp4(module_type):
"""Test modules with NVFP4 custom recipe support"""
available, reason = te.is_fp8_available(return_reason=True)
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
# Simple linear layer with dims divisible by 16
in_features = 64
out_features = 64
batch = 32
if module_type == "Linear":
model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda()
elif module_type == "LayerNormLinear":
model = LayerNormLinear(
in_features, out_features, params_dtype=torch.bfloat16, bias=False
).cuda()
else: # OpsLinear
model = te_ops.Linear(
in_features, out_features, device="cuda", dtype=torch.bfloat16, bias=False
)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Use NVFP4 quantizer factory
custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
# Execute with custom recipe
with autocast(enabled=True, recipe=custom_recipe):
out = model(inp)
loss = out.float().sum()
loss.backward()
# Basic sanity: gradients exist
assert inp.grad is not None
@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
......
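
The test above hands nvfp4_ref_rht_2d_quantizer_factory to recipe.CustomRecipe; a user-defined factory follows the same shape, taking a tensor role and returning a quantizer. A minimal hypothetical sketch (the role values and the constructor defaults of NVFP4QuantizerRef are assumptions):

from transformer_engine.common import recipe
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef

def my_quantizer_factory(role):
    # role identifies which tensor is being quantized (e.g. input, weight,
    # grad); the exact role values are an assumption in this sketch.
    return NVFP4QuantizerRef()  # assumes the constructor arguments have defaults

my_recipe = recipe.CustomRecipe(qfactory=my_quantizer_factory)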
......@@ -15,7 +15,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
from transformer_engine.pytorch.quantized_tensor import (
QuantizedTensor,
Quantizer,
QuantizedTensorStorage,
......
......@@ -66,24 +66,24 @@ from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.quantized_tensor import Quantizer
from transformer_engine.pytorch.quantized_tensor import prepare_for_saving
from transformer_engine.pytorch.quantized_tensor import restore_from_saved
from transformer_engine.pytorch.tensor import Float8Quantizer
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor import NVFP4Quantizer
from transformer_engine.pytorch.tensor import QuantizedTensorStorage
from transformer_engine.pytorch.tensor import Float8TensorStorage
from transformer_engine.pytorch.tensor import MXFP8TensorStorage
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage
from transformer_engine.pytorch.tensor import NVFP4TensorStorage
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor import Float8Tensor
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor import prepare_for_saving
from transformer_engine.pytorch.tensor import restore_from_saved
try:
torch._dynamo.config.error_on_nested_jit_trace = False
......
......@@ -24,7 +24,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
from transformer_engine.pytorch.quantized_tensor import (
QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
......
......@@ -21,7 +21,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
)
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.jit import jit_fuser
from transformer_engine.pytorch.constants import (
dist_group_type,
......@@ -33,7 +33,7 @@ from transformer_engine.pytorch.distributed import (
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
......
......@@ -15,7 +15,7 @@ from transformer_engine_torch import (
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
from ..quantized_tensor import Quantizer
__all__ = [
......
......@@ -11,10 +11,10 @@ import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
from ..tensor.quantized_tensor import Quantizer
from ..quantized_tensor import Quantizer
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [
......@@ -79,9 +79,9 @@ def general_gemm(
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
# If A or B are experimental tensors -> dispatch to quantizer's qgemm implementation
if is_experimental(A) or is_experimental(B):
return experimental_gemm(
# If A or B are custom tensors -> dispatch to quantizer's qgemm implementation
if is_custom(A) or is_custom(B):
return custom_gemm(
A,
B,
workspace,
......
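
The is_custom check presumably keys off the custom property introduced later in this diff; a sketch consistent with that property (not the verbatim helper from tensor/utils.py):

def is_custom(tensor) -> bool:
    # Custom storages and quantizers advertise themselves through a
    # `custom` property that returns True; native tensors lack it.
    return getattr(tensor, "custom", False)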
......@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .tensor.quantized_tensor import QuantizedTensorStorage
from .quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
......
......@@ -2,21 +2,21 @@
#
# See LICENSE for license information.
"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""
"""GEMM API that enables custom GEMM logic for custom quantization recipes."""
from typing import Iterable, Optional
import torch
from transformer_engine.pytorch.experimental.quantization import (
from transformer_engine.pytorch.custom_recipes.quantization import (
MMParams,
GEMMType,
)
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_experimental
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_custom
def experimental_gemm(
def custom_gemm(
A: QuantizedTensorStorage,
B: QuantizedTensorStorage,
workspace: torch.Tensor, # pylint: disable=unused-argument
......@@ -32,7 +32,7 @@ def experimental_gemm(
grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
"""Dispatch GEMM to quantizer's qgemm method."""
assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors"
assert is_custom(A) and is_custom(B), "A and B must be custom tensors"
A, B = B, A
......
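
Per the docstring above, custom_gemm defers to the quantizer's qgemm method. The contract a custom quantizer has to satisfy therefore looks roughly as follows (the qgemm signature is an assumption based on the MMParams/GEMMType imports, not the exact interface):

from transformer_engine.pytorch.quantized_tensor import Quantizer

class MyCustomQuantizer(Quantizer):
    @property
    def custom(self) -> bool:
        # Routes tensors produced by this quantizer into custom_gemm.
        return True

    def qgemm(self, A, B, mm_params, gemm_type, out=None):
        # Implement the recipe's low-precision matmul here.
        raise NotImplementedError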
......@@ -9,9 +9,9 @@ from typing import Optional, Tuple, Union
import torch
from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.custom_recipes import quantization
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
def nvfp4_ref_rht_2d_quantizer_factory(role):
......@@ -229,8 +229,8 @@ class NVFP4TensorRef(QuantizedTensorStorage):
_quantizer: Optional[Quantizer] = None
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware."""
def custom(self) -> bool:
"""Flag to indicate this quantized tensor is custom."""
return True
def prepare_for_saving(
......@@ -362,8 +362,8 @@ class NVFP4QuantizerRef(Quantizer):
self.with_random_sign_mask = with_random_sign_mask
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware"""
def custom(self) -> bool:
"""Flag to indicate this quantizer is custom."""
return True
@staticmethod
......