"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "f7389f4763c37579d249d0f9d80917e2ecfc4ead"
Unverified Commit dfacd9f7 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Use Quantization API for reference NVFP4 recipe (#2259)



* Fix update_quantized in ref nvfp4 quantizer
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Subclass quantization API
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Use recipe.Custom and quantizer factories for reference NVFP4
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Linter fix
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5ec0f33b
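
In effect, this commit routes the reference NVFP4 path through the public quantization API instead of the QAT_PARAMS environment variable. A minimal sketch of the new pattern, assembled from the imports, role names, and quantizer arguments shown in the diff below (model and inp are placeholders, not part of the change):

from transformer_engine.common.recipe import CustomRecipe
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.pytorch.experimental import quantization_nvfp4, utils

def qfactory(role):
    # Map each tensor role to a reference quantizer; None leaves it unquantized.
    if role == "linear_input":
        return quantization_nvfp4.NVFP4QuantizerRef(
            dtype=utils.Fp4Formats.E2M1,
            quant_tile_shape=(1, 16),
            pow_2_scales=False,
            with_rht=True,
        )
    return None

with fp8_autocast(enabled=True, fp8_recipe=CustomRecipe(qfactory=qfactory)):
    out = model(inp)
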
......@@ -21,9 +21,12 @@ from transformer_engine.common.recipe import (
Format,
Recipe,
QParams,
CustomRecipe,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
from run_layer_with_overlap import _compare_tensors
......@@ -48,6 +51,52 @@ def nvfp4_rht_and_2d_quantization():
return nvfp4_recipe
def get_nvfp4_quantizer_factory():
"""
Create a quantizer factory for NVFP4 reference implementation.
This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization
enabled.
Returns:
A factory function that takes a role string and returns a quantizer instance
"""
def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for input
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16), # 2D quantization for weight
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for grad_output
)
elif role == "linear_grad_input":
# Grad input quantization not used
return None
else:
# For any other roles, return None
return None
return factory
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "nvfp4":
......@@ -55,16 +104,12 @@ def quantization_recipe() -> Recipe:
raise ValueError(f"Unsupported quantization: {QUANTIZATION}")
def setup_environment_for_reference():
def quantization_reference_recipe() -> Recipe:
"""Create reference recipe using CustomRecipe with NVFP4 quantizer factory."""
if QUANTIZATION == "nvfp4":
os.environ["QAT_PARAMS"] = "9003"
else:
raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}")
def cleanup_environment():
if "QAT_PARAMS" in os.environ:
del os.environ["QAT_PARAMS"]
nvfp4_ref_factory = get_nvfp4_quantizer_factory()
return CustomRecipe(qfactory=nvfp4_ref_factory)
raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}")
def main(argv=None, namespace=None):
......@@ -478,8 +523,8 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
)
# run the reference
setup_environment_for_reference()
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
reference_recipe = quantization_reference_recipe()
with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
x,
w,
......@@ -494,8 +539,6 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
run_num_steps=run_num_steps,
enable_weight_cache=enable_weight_cache,
)
# Clean up env
cleanup_environment()
# compare results, zero tolerance
if WORLD_RANK == 0:
......@@ -673,8 +716,8 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
)
# run the reference
setup_environment_for_reference()
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
reference_recipe = quantization_reference_recipe()
with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
TestDistributedLayerNormLinearBase.run_layernorm_linear(
x,
......@@ -690,8 +733,6 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
enable_weight_cache=False,
)
)
# Clean up env
cleanup_environment()
# compare results, zero tolerance
if WORLD_RANK == 0:
......
......@@ -9,7 +9,7 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
......
......@@ -2,13 +2,14 @@
#
# See LICENSE for license information.
import os
import pytest
import torch
import transformer_engine as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common import recipe
from transformer_engine.pytorch.experimental import quantization_nvfp4
from transformer_engine.pytorch.experimental import utils
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
......@@ -65,20 +66,54 @@ class GetRecipes:
return GetRecipes.nvfp4_vanilla()
def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False):
if with_rht and with_2d_quantization:
os.environ["QAT_PARAMS"] = "9003"
elif with_rht:
os.environ["QAT_PARAMS"] = "960109"
elif with_2d_quantization:
os.environ["QAT_PARAMS"] = "9002"
else:
os.environ["QAT_PARAMS"] = "6010"
def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bool = False):
"""
Create a quantizer factory for NVFP4 reference implementation.
This factory returns NVFP4QuantizerRef instances based on the role and configuration.
Used with CustomRecipe to create reference quantizers.
Args:
with_rht: Whether to enable random Hadamard transform
with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights)
Returns:
A factory function that takes a role string and returns a quantizer instance
"""
def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_grad_input":
# Grad input quantization not used
return None
else:
# For any other roles, return None
return None
def cleanup_environment():
if "QAT_PARAMS" in os.environ:
del os.environ["QAT_PARAMS"]
return factory
def reset_rng_states():
......@@ -113,7 +148,6 @@ def check_nvfp4_module_versus_reference(
seq_len = 128
# Create both modules with identical initialization
cleanup_environment()
reset_rng_states()
# Create native module
......@@ -138,7 +172,6 @@ def check_nvfp4_module_versus_reference(
raise ValueError(f"Unsupported module class: {module_class}")
# Create reference module with same weights
setup_environment_for_reference(with_rht, with_2d_quantization)
reset_rng_states()
# Create reference module
......@@ -174,7 +207,10 @@ def check_nvfp4_module_versus_reference(
if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
# Create recipes for native and reference implementations
nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)
# Training loop comparison
native_outputs = []
......@@ -196,17 +232,13 @@ def check_nvfp4_module_versus_reference(
grad_output = grad_output_val.clone().detach()
# Native forward/backward
cleanup_environment()
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
# enable weight cache by giving is_first_microbatch
y_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
setup_environment_for_reference(with_rht, with_2d_quantization)
with fp8_autocast(
enabled=True, fp8_recipe=nvfp4_recipe
): # Exact recipe does not play a role here
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe):
y_ref = ref_module(x_ref)
y_ref.backward(grad_output)
......@@ -295,9 +327,6 @@ def check_nvfp4_module_versus_reference(
msg=f"Bias gradient mismatch at step {step}",
)
# Clean up
cleanup_environment()
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
......@@ -362,7 +391,6 @@ def check_nvfp4_layernorm_linear_versus_reference(
seq_len = 128
# Create both modules with identical initialization
cleanup_environment()
reset_rng_states()
# Native module
......@@ -377,7 +405,6 @@ def check_nvfp4_layernorm_linear_versus_reference(
)
# Reference module
setup_environment_for_reference(with_rht, with_2d_quantization)
reset_rng_states()
ref_module = te.pytorch.LayerNormLinear(
in_features=in_features,
......@@ -405,7 +432,10 @@ def check_nvfp4_layernorm_linear_versus_reference(
if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None:
ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
# Create recipes for native and reference implementations
nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization)
nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory)
native_outputs = []
ref_outputs = []
......@@ -426,14 +456,12 @@ def check_nvfp4_layernorm_linear_versus_reference(
grad_output = grad_output_val.clone().detach()
# Native forward/backward
cleanup_environment()
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
y_native.backward(grad_output)
# Reference forward/backward
setup_environment_for_reference(with_rht, with_2d_quantization)
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe):
y_ref, ln_out_ref = ref_module(x_ref)
y_ref.backward(grad_output)
......@@ -515,8 +543,6 @@ def check_nvfp4_layernorm_linear_versus_reference(
msg=f"Bias gradient mismatch at step {step}",
)
cleanup_environment()
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
......
......@@ -12,7 +12,7 @@ from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
......
......@@ -20,7 +20,7 @@ from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
NVFP4Quantizer,
)
from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
......
......@@ -546,6 +546,8 @@ class DotProductAttention(TransformerEngineBaseModule):
# global recipe set in fp8_autocast()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.custom():
return
# switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to
# a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well.
......
......@@ -3,8 +3,3 @@
# See LICENSE for license information.
"""Experimental features and APIs."""
from .config import set_qlinear_params, get_experimental_quantizers
__all__ = ["set_qlinear_params", "get_experimental_quantizers"]
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Config API for experimental middleware between Transformer Engine and Kitchen."""
import dataclasses
import enum
import os
from typing import Optional
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import quantization_microblock_ref
from transformer_engine.pytorch.experimental.quantization import MMParams
@dataclasses.dataclass()
class QLinearParams:
"""Quantization parameters of linear layer.
Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors.
"""
x_quantizer: Optional[quantization.ExperimentalQuantizer] = None
w_quantizer: Optional[quantization.ExperimentalQuantizer] = None
g_quantizer: Optional[quantization.ExperimentalQuantizer] = None
mm_fprop: Optional[MMParams] = None
mm_dgrad: Optional[MMParams] = None
mm_wgrad: Optional[MMParams] = None
@enum.unique
class QuantizeRecipe(enum.Enum):
"""Pre-defined quantization recipes for linear layers."""
NON_QUANTIZE = "non_quantize"
NVFP4_REF = "nvfp4_ref"
NVFP4_REF_RHT_ONLY = "nvfp4_ref_rht_only"
NVFP4_REF_2D_QUANTIZATION_ONLY = "nvfp4_ref_2d_quantization_only"
NVFP4_REF_RHT_AND_2D_QUANTIZATION = "nvfp4_ref_rht_and_2d_quantization"
def get_qlinear_params_from_predefined(
recipe: QuantizeRecipe,
) -> Optional[QLinearParams]:
"""Get quantization parameters for linear layer based on recipe."""
if recipe == QuantizeRecipe.NON_QUANTIZE:
return None
if recipe == QuantizeRecipe.NVFP4_REF:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
),
)
if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
)
if recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=False,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=False,
),
)
if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
)
raise ValueError(f"Unsupported quantize recipe: {recipe}")
def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]:
"""Load quantization options from Kitchen to Transformer Engine.
TODO(etsykunov): Confirm docstring is correct.
"""
assert qat_params_idx > 0, "QAT_PARAMS is not set."
if qat_params_idx == 6010:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF)
if qat_params_idx == 960109:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY)
if qat_params_idx == 9002:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY)
if qat_params_idx == 9003:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION)
raise ValueError(f"Unsupported QAT params index: {qat_params_idx}")
def set_qlinear_params(
qlinear_params: Optional[QLinearParams] = None,
layer_number: Optional[int] = None,
layer_name: Optional[str] = None,
) -> Optional[QLinearParams]:
"""Set quantization parameters based on configuration.
Args:
qlinear_params: Quantization parameters. If None, loaded from environment.
layer_number: The numerical index of this layer in the model structure.
layer_name: The name for this layer.
Returns:
QLinearParams: The finalized quantization parameters for this layer.
"""
if qlinear_params is None:
qat_params_idx = int(os.getenv("QAT_PARAMS", "0"))
if qat_params_idx == 0:
return None
return get_qlinear_params_from_qat_params(qat_params_idx)
# Apply layer-specific overrides
if layer_number is not None:
raise NotImplementedError("Layer-specific overrides are not supported yet.")
if layer_name is not None:
raise NotImplementedError("Layer-specific overrides are not supported yet.")
return qlinear_params
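# Example (hedged): explicit parameters pass through unchanged; passing None
# falls back to the legacy QAT_PARAMS environment variable, returning None
# when it is unset or "0".
#     >>> os.environ["QAT_PARAMS"] = "6010"
#     >>> set_qlinear_params().x_quantizer.quant_tile_shape
#     (1, 16)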
def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams):
"""Replacement of _get_quantizers() in TE modules."""
if not fp8:
raise ValueError("FP8 is required to be enabled for experimental quantization.")
input_quantizer = qlinear_params.x_quantizer
weight_quantizer = qlinear_params.w_quantizer
output_quantizer = None
grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = qlinear_params.g_quantizer
return (
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
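# Example (hedged): TE modules unpack the 6-tuple in this order; the output and
# grad-input/grad-weight slots are always None in this experimental path.
#     >>> qp = get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF)
#     >>> x_q, w_q, out_q, gin_q, gw_q, gout_q = get_experimental_quantizers(True, qp)
#     >>> (out_q, gin_q, gw_q)
#     (None, None, None)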
......@@ -11,14 +11,14 @@ import torch
from transformer_engine.pytorch.experimental.quantization import (
MMParams,
GEMMType,
ExperimentalQuantizedTensor,
)
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_experimental
def experimental_gemm(
A: ExperimentalQuantizedTensor,
B: ExperimentalQuantizedTensor,
A: QuantizedTensorStorage,
B: QuantizedTensorStorage,
workspace: torch.Tensor, # pylint: disable=unused-argument
out_dtype: Optional[torch.dtype] = None,
quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument
......@@ -32,9 +32,7 @@ def experimental_gemm(
grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
"""Dispatch GEMM to quantizer's qgemm method."""
assert isinstance(A, ExperimentalQuantizedTensor) and isinstance(
B, ExperimentalQuantizedTensor
), "A and B must be ExperimentalQuantizedTensor instances"
assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors"
A, B = B, A
......@@ -51,14 +49,14 @@ def experimental_gemm(
gemm_type = GEMMType.FPROP
# Extract quantizer from QuantizedTensor to get qgemm logic
# TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B.quantizer?
# TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B._quantizer?
quantizer = None
if hasattr(A, "quantizer") and A.quantizer is not None:
quantizer = A.quantizer
elif hasattr(B, "quantizer") and B.quantizer is not None:
quantizer = B.quantizer
if hasattr(A, "_quantizer") and A._quantizer is not None:
quantizer = A._quantizer
elif hasattr(B, "_quantizer") and B._quantizer is not None:
quantizer = B._quantizer
else:
raise ValueError("No quantizer found in QuantizedETensor objects")
raise ValueError("No quantizer found in QuantizedTensor objects")
# Create MMParams
m_params = MMParams(
......
......@@ -5,17 +5,11 @@
"""Quantization API for experimental middleware between Transformer Engine and Kitchen."""
from __future__ import annotations
import abc
import dataclasses
import enum
from typing import Iterable, Optional, Tuple, Union
import torch
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.experimental import utils
@enum.unique
class GEMMType(enum.Enum):
......@@ -33,171 +27,3 @@ class MMParams:
out_dtype: torch.dtype | None = None
# Use split accumulator for more accurate FP8 GEMM
use_split_accumulator: bool = True
@dataclasses.dataclass
class ExperimentalQuantizedTensor(QuantizedTensorStorage):
"""Base class for experimental quantized tensor containers.
An experimental container to hold quantization result, including quantized tensor, optional
transposed quantized tensor, and corresponding decoding scales.
data: torch.Tensor
the quantized tensor.
scale: torch.Tensor
the decoding scale for the quantized tensor. Shape depends on the scaling granularity.
- if scaling type is PER_TENSOR, it should be a 1D scalar tensor.
data_t: torch.Tensor
the transposed quantized tensor (computed lazily if needed).
scale_t: torch.Tensor
the decoding scale for the transposed quantized tensor.
dtype: torch.dtype
nominal tensor datatype.
device: torch.device
device of the tensor.
quant_dtype: Union[utils.Fp4Formats, torch.dtype]
low precision tensor datatype.
original_shape: Tuple[int, ...]
original shape of the tensor.
quantizer: ExperimentalQuantizer
Builder class for quantized tensor.
"""
data: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
data_t: Optional[torch.Tensor] = None
scale_t: Optional[torch.Tensor] = None
global_amax_row: Optional[torch.Tensor] = None
global_amax_col: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
device: Optional[torch.device] = None
quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None
original_shape: Optional[Tuple[int, ...]] = None
quantizer: Optional[ExperimentalQuantizer] = None
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware."""
return True
def get_quantizer(self) -> ExperimentalQuantizer:
"""Get builder for QuantizedExperimentalTensor
Quantizer can be used for in-place operations.
"""
if self.quantizer is not None:
return self.quantizer
raise ValueError("Quantizer is not set")
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]:
"""Prepare the quantization result for saving for backward"""
tensors = [self.data, self.data_t, self.scale, self.scale_t]
self.data = None
self.data_t = None
self.scale = None
self.scale_t = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the quantization result from the saved tensors"""
self.data = tensors[0]
self.data_t = tensors[1]
self.scale = tensors[2]
self.scale_t = tensors[3]
return tensors[4:]
def dequantize(self, *args, **kwargs) -> torch.Tensor:
"""Dequantize the quantized tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement dequantize function"
)
# Compatibility
@property
def _data(self):
return self.data
@_data.setter
def _data(self, value):
self.data = value
@property
def _scale_inv(self):
return self.scale
@_scale_inv.setter
def _scale_inv(self, value):
self.scale = value
class ExperimentalQuantizer(Quantizer):
"""Experimental Quantizer class
Defines the interface for experimental quantizers.
"""
def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.internal = True
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware"""
return True
@abc.abstractmethod
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: MMParams,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: GEMMType = GEMMType.FPROP,
qresult_x: ExperimentalQuantizedTensor | None = None,
qresult_w: ExperimentalQuantizedTensor | None = None,
) -> torch.Tensor:
"""Quantized GEMM interface."""
def dequantize(self, *args, **kwargs) -> torch.Tensor:
"""Dequantize the quantized tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement dequantize function"
)
def update_quantized(self, *args, **kwargs) -> torch.Tensor:
"""Update the quantized tensor with the given tensor in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_quantized function"
)
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> QuantizedTensorStorage:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function"
)
def calibrate(self, tensor: torch.Tensor) -> None:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement calibrate function"
)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_compatible_recipe function"
)
......@@ -2,18 +2,49 @@
#
# See LICENSE for license information.
"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen."""
"""NVFP4 recipe reference implementation."""
from typing import Optional, Tuple
import dataclasses
from typing import Optional, Tuple, Union
import torch
from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental.quantization import (
ExperimentalQuantizedTensor,
ExperimentalQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
def nvfp4_ref_rht_2d_quantizer_factory(role):
"""
Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights).
Usage with CustomRecipe and fp8_autocast:
custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
with fp8_autocast(fp8_recipe=custom_recipe):
output = model(input)
"""
if role == "linear_input":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
)
if role == "linear_weight":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
)
if role == "linear_grad_output":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
)
return None
def cast_to_fp4x2(x):
......@@ -156,8 +187,89 @@ def high_precision_gemm_ref(
return y_ref
class NVFP4TensorRef(ExperimentalQuantizedTensor):
"""NVFP4 tensor for middleware between Transformer Engine and Kitchen"""
@dataclasses.dataclass
class NVFP4TensorRef(QuantizedTensorStorage):
"""NVFP4 tensor for middleware between Transformer Engine and Kitchen.
Custom container to hold quantization result, including quantized tensor, optional
transposed quantized tensor, and corresponding decoding scales.
data: torch.Tensor
the quantized tensor.
scale: torch.Tensor
the decoding scale for the quantized tensor. Shape depends on the scaling granularity.
- if scaling type is PER_TENSOR, it should be a 1D scalar tensor.
data_t: torch.Tensor
the transposed quantized tensor (computed lazily if needed).
scale_t: torch.Tensor
the decoding scale for the transposed quantized tensor.
dtype: torch.dtype
nominal tensor datatype.
device: torch.device
device of the tensor.
quant_dtype: Union[utils.Fp4Formats, torch.dtype]
low precision tensor datatype.
original_shape: Tuple[int, ...]
original shape of the tensor.
_quantizer: Quantizer
Builder class for quantized tensor.
"""
data: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
data_t: Optional[torch.Tensor] = None
scale_t: Optional[torch.Tensor] = None
global_amax_row: Optional[torch.Tensor] = None
global_amax_col: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
device: Optional[torch.device] = None
quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None
original_shape: Optional[Tuple[int, ...]] = None
_quantizer: Optional[Quantizer] = None
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware."""
return True
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the quantization result for saving for backward"""
tensors = [self.data, self.data_t, self.scale, self.scale_t]
self.data = None
self.data_t = None
self.scale = None
self.scale_t = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the quantization result from the saved tensors"""
self.data = tensors[0]
self.data_t = tensors[1]
self.scale = tensors[2]
self.scale_t = tensors[3]
return tensors[4:]
# Compatibility
@property
def _data(self):
return self.data
@_data.setter
def _data(self, value):
self.data = value
@property
def _scale_inv(self):
return self.scale
@_scale_inv.setter
def _scale_inv(self, value):
self.scale = value
def __repr__(self):
return (
......@@ -165,47 +277,10 @@ class NVFP4TensorRef(ExperimentalQuantizedTensor):
f"dtype={self.dtype}, "
f"device={self.device}, "
f"quant_dtype={self.quant_dtype}, "
f"data={self.dequantize(dtype=self.dtype)}, "
f"original_shape={self.original_shape}"
")"
)
def quantize_(
self,
tensor: torch.Tensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> ExperimentalQuantizedTensor:
"""In-place update of quantized data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if isinstance(tensor, ExperimentalQuantizedTensor):
return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from quantized tensor
"""
if dtype is None:
dtype = self.dtype
# Ignore data_t for now
assert self.data is not None, "QuantizedTensor has no valid tensor data"
assert self.scale is not None, "QuantizedTensor has no valid scale"
tensor_data = self.data
tensor_scale = self.scale
# Dispatch to the quantizer
return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
......@@ -224,10 +299,10 @@ class NVFP4TensorRef(ExperimentalQuantizedTensor):
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
raise RuntimeError("Cannot generate FP4 data, even from FP4 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP8 data is required to generate FP8 data transpose")
raise RuntimeError("FP4 data is required to generate FP4 data transpose")
self._create_transpose()
# Delete data that is not required
......@@ -262,7 +337,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
)
class NVFP4QuantizerRef(ExperimentalQuantizer):
class NVFP4QuantizerRef(Quantizer):
"""NVFP4 quantizer for middleware between Transformer Engine and Kitchen"""
def __init__(
......@@ -277,6 +352,8 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
with_random_sign_mask: bool = True,
):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.internal = True
self.dtype = dtype
self.pow_2_scales = pow_2_scales
self.eps = eps
......@@ -284,6 +361,11 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
self.with_rht = with_rht
self.with_random_sign_mask = with_random_sign_mask
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware"""
return True
@staticmethod
def _build_hadamard_matrix(
size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
......@@ -500,7 +582,7 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
- sx: scale tensor for qx (if rowwise_usage), None otherwise
- qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
- sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise
- global_amax: global amax tensor
- global_amax_row, global_amax_col: global amax tensors
"""
if self.pow_2_scales:
assert self.quant_tile_shape == (
......@@ -607,25 +689,25 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
dtype=tensor.dtype,
device=tensor.device,
quant_dtype=self.dtype,
quantizer=self,
_quantizer=self,
original_shape=original_shape,
)
def update_quantized(
self,
src: torch.Tensor,
dst: ExperimentalQuantizedTensor,
dst: QuantizedTensorStorage,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> ExperimentalQuantizedTensor:
) -> QuantizedTensorStorage:
"""Update the quantized tensor with the given tensor in-place
Parameters
----------
src: torch.Tensor
Source tensor to copy from
dst: ExperimentalQuantizedTensor
Destination ExperimentalQuantizedTensor to update
dst: QuantizedTensorStorage
Destination QuantizedTensorStorage to update
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
......@@ -642,14 +724,15 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
if src.ndim > 2:
src = src.view(-1, src.shape[-1])
qx, sx, qx_t, sx_t, global_amax = self._quantize(src)
qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src)
# Update the destination with new data
dst.data = qx
dst.scale = sx
dst.data_t = qx_t
dst.scale_t = sx_t
dst.global_amax = global_amax
dst.global_amax_row = global_amax_row
dst.global_amax_col = global_amax_col
dst.dtype = src.dtype
dst.quant_dtype = self.dtype
dst.original_shape = original_shape
......@@ -665,9 +748,7 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
"""
return False
def transpose_qresult(
self, qresult: quantization.ExperimentalQuantizedTensor
) -> quantization.ExperimentalQuantizedTensor:
def transpose_qresult(self, qresult: QuantizedTensorStorage) -> QuantizedTensorStorage:
"""Convert row-wise data to column-wise data (?)
TODO(etsykunov): Confirm docstring is correct.
......@@ -687,17 +768,11 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
"""
raise NotImplementedError("Not implemented yet")
def dequantize(
self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
"""Dequantize the quantized tensor"""
raise NotImplementedError("Not implemented yet")
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: quantization.MMParams,
m_params: quantization.MMParams, # pylint: disable=unused-argument
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
......@@ -705,9 +780,10 @@ class NVFP4QuantizerRef(ExperimentalQuantizer):
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,
qresult_x: quantization.ExperimentalQuantizedTensor | None = None,
qresult_w: quantization.ExperimentalQuantizedTensor | None = None,
qresult_x: QuantizedTensorStorage | None = None,
qresult_w: QuantizedTensorStorage | None = None,
) -> torch.Tensor:
"""Python implementation of microblock FP4 GEMM."""
assert bias is None, "Bias is implemented for FP4 GEMM."
high_precision_x = cast_from_fp4x2(qx, out_dtype)
......
......@@ -11,10 +11,8 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from .. import cpp_extensions as tex
from .. import experimental
from ..constants import TE_DType
from ..export import is_in_onnx_export_mode
from ..tensor.utils import is_experimental
from ..utils import get_default_init_method
......@@ -172,32 +170,6 @@ def noop_cat(
return _NoopCatFunc.apply(dim, *tensors)
def get_module_quantizers(
module: torch.nn.Module,
fp8_output: bool,
fp8_grad: bool,
debug: bool,
):
"""Return the 6-tuple of quantizers for a module in a centralized way.
Routing policy:
- If experimental quantization is enabled via environment and module.fp8 is True,
return experimental quantizers.
- Otherwise, return the module's own quantizers (debug or regular).
"""
if getattr(module, "fp8", False) and is_experimental():
# TODO(etsykunov): Quantizer instantiation should be better
# done in the module's constructor
qlinear_params = experimental.config.set_qlinear_params()
if qlinear_params is not None:
return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params)
if not debug:
return module._get_quantizers(fp8_output, fp8_grad)
return module._get_debug_quantizers(fp8_output, fp8_grad)
@dataclasses.dataclass
class _ParameterInitMeta:
"""
......
......@@ -55,7 +55,7 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers
from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
......@@ -1541,7 +1541,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
......
......@@ -25,7 +25,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, WeightGradStore, get_module_quantizers
from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
......@@ -1428,7 +1428,11 @@ class Linear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
......