Unverified Commit 730fd115, authored by Evgeny Tsykunov, committed by GitHub

Enhance recipe compatibility (#1724)



* Check tensor-recipe compatibility
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Tensor class in recipe, checking for *Base
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Extend recipe __repr__ with recipe_type
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Warn about recipe change
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enable dynamic recipe change: clear fp8 workspace
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* TE 1.x checkpoint compatibility
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Disable warning for recipe wrappers
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Test recipe change
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use QuantizedTensorBase
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Fix circular import
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Revert previous circular import fix
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Fix pytorch imports in common
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Let quantizer know about the recipe
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix imports
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

---------

Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7be43390
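
In practical terms, a minimal usage sketch distilled from the new tests below (assumes an FP8-capable GPU on which the MXFP8 recipe is supported): a weight/recipe mismatch now fails fast, and a mid-training recipe switch warns and rebuilds the FP8 workspaces.

import torch
from transformer_engine.pytorch import Linear
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

# Weights are created (and quantized) for MXFP8BlockScaling at init time...
with fp8_model_init(enabled=True, recipe=MXFP8BlockScaling()):
    linear = Linear(32, 32).cuda()
x = torch.randn(32, 32, device="cuda")

# ...but the forward pass runs under a DelayedScaling autocast. This now fails
# fast with "Recipe mismatch for ..." instead of silently producing bad numerics.
try:
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        _ = linear(x)
except RuntimeError as err:
    print(err)

# Switching recipes between iterations is supported: the first forward pass under
# the new recipe emits a "Recipe type changed" UserWarning and clears the module's
# cached FP8 weight workspaces so they are rebuilt with the new quantizer type.
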
@@ -6,21 +6,31 @@ from typing import Iterable, Optional
 import pytest
 import torch
+import warnings

 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.fp8 import (
     FP8GlobalStateManager,
     _amax_and_scale_update,
-    get_default_fp8_recipe,
+    fp8_model_init,
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 import transformer_engine.pytorch.ops as te_ops
+from transformer_engine.pytorch import Linear
+from transformer_engine.pytorch.distributed import fp8_autocast
+from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
 import transformer_engine_torch as tex

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)

 # FP8 per tensor delayed scaling
@@ -367,3 +377,96 @@ class TestFP8Recipe:
         )
         torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)

+    @pytest.mark.parametrize(
+        "model_init_recipe",
+        [
+            pytest.param(
+                MXFP8BlockScaling(),
+                marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
+            ),
+            pytest.param(
+                Float8BlockScaling(),
+                marks=pytest.mark.skipif(
+                    not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
+                ),
+            ),
+        ],
+    )
+    def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
+        with fp8_model_init(enabled=True, recipe=model_init_recipe):
+            linear = Linear(32, 32).cuda()
+        x = torch.randn(32, 32, device="cuda")
+        with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
+            with pytest.raises(RuntimeError) as excinfo:
+                _ = linear(x)
+        assert "Recipe mismatch for " in str(excinfo.value)
+
+    @pytest.mark.parametrize(
+        "target_recipe_class, expected_quantizer_type, available_flag, reason",
+        [
+            pytest.param(
+                MXFP8BlockScaling,
+                MXFP8Quantizer,
+                mxfp8_available,
+                reason_for_no_mxfp8,
+                id="DelayedScaling->MXFP8BlockScaling",
+            ),
+            pytest.param(
+                Float8BlockScaling,
+                Float8BlockQuantizer,
+                fp8_block_scaling_available,
+                reason_for_no_fp8_block_scaling,
+                id="DelayedScaling->Float8BlockScaling",
+            ),
+        ],
+    )
+    def test_dynamic_recipe_update(
+        self, target_recipe_class, expected_quantizer_type, available_flag, reason
+    ):
+        if not available_flag:
+            pytest.skip(reason)
+
+        in_features = 32
+        out_features = 32
+        batch_size = 32
+        linear = Linear(in_features, out_features).cuda()
+        initial_recipe = DelayedScaling()
+
+        # Run initial iterations with DelayedScaling
+        for _ in range(3):
+            x = torch.randn(batch_size, in_features, device="cuda")
+            with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
+                y = linear(x)
+            loss = y.mean()
+            loss.backward()
+        for quantizer in linear.quantizers["scaling_fwd"]:
+            assert isinstance(quantizer, Float8Quantizer)
+
+        # Change recipe
+        target_recipe = target_recipe_class()
+
+        # Run subsequent iterations with the target recipe
+        for i in range(3):
+            x = torch.randn(batch_size, in_features, device="cuda")
+            if i == 0:
+                # Expect a warning on the first iteration with the new recipe
+                with pytest.warns(UserWarning, match="Recipe type changed"):
+                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
+                        y = linear(x)
+                for quantizer in linear.quantizers["scaling_fwd"]:
+                    assert isinstance(quantizer, expected_quantizer_type)
+            else:
+                # No warning expected on subsequent iterations
+                with warnings.catch_warnings():
+                    warnings.simplefilter("error")  # Raise error if unexpected warning occurs
+                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
+                        y = linear(x)
+            loss = y.mean()
+            loss.backward()
+
+        # Final check
+        for quantizer in linear.quantizers["scaling_fwd"]:
+            assert isinstance(quantizer, expected_quantizer_type)
@@ -87,7 +87,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     A.scaling_mode == B.scaling_mode ||
         (A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
         (A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
-    "Inputs A and B to GEMM need to have compatible scaling modes!");
+    "Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
+        to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
 NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
 NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
 GemmParam ret;
......
@@ -180,6 +180,7 @@ class DelayedScaling(Recipe):
     def __repr__(self) -> str:
         return (
+            f"recipe_type={self.__class__.__name__}, "
             f"margin={self.margin}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"amax_history_len={self.amax_history_len}, "
@@ -245,6 +246,7 @@ class Float8CurrentScaling(Recipe):
     def __repr__(self) -> str:
         return (
+            f"recipe_type={self.__class__.__name__}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
             f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
@@ -291,7 +293,11 @@ class MXFP8BlockScaling(Recipe):
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."

     def __repr__(self) -> str:
-        return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"
+        return (
+            f"recipe_type={self.__class__.__name__}, "
+            f"margin={self.margin}, "
+            f"format={str(self.fp8_format).split('.')[1]}"
+        )


 @dataclass()
@@ -375,6 +381,7 @@ class Float8BlockScaling(Recipe):
     def __repr__(self) -> str:
         return (
+            f"recipe_type={self.__class__.__name__}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
             f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
......
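
For reference, a small hedged example of what the extended __repr__ now reports; the fields after recipe_type are each recipe's own attributes and their default values are abbreviated here.

from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

# Each recipe's repr now leads with its class name, which makes recipe changes
# easy to spot in logs.
print(DelayedScaling())     # recipe_type=DelayedScaling, margin=..., format=..., amax_history_len=..., ...
print(MXFP8BlockScaling())  # recipe_type=MXFP8BlockScaling, margin=..., format=...
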
@@ -14,7 +14,7 @@ import torch

 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch.tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
@@ -459,6 +459,10 @@ class DebugQuantizer(Quantizer):
             return True
         return False

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        """Probably not needed for debug quantizer"""
+        return None
+

 class DebugQuantizedTensor(QuantizedTensorBase):
     """
......
...@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase ...@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
...@@ -811,6 +811,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -811,6 +811,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None: if state is None:
return return
# TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
if "recipe" not in state:
# TE 1.x only supported delayed scaling, which was the default recipe
state["recipe"] = DelayedScaling()
# TE 1.x also saved scale_inv, which is not needed with Recipe object
state.pop("scale_inv_fwd", None)
state.pop("scale_inv_bwd", None)
# Load extra items # Load extra items
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"] = state["recipe"] self.fp8_meta["recipe"] = state["recipe"]
...@@ -884,6 +892,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -884,6 +892,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution. # assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None: def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop.""" """Initialize fp8 related metadata and tensors during fprop."""
_original_recipe = self.fp8_meta.get("recipe", None)
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
...@@ -922,6 +932,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -922,6 +932,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
@contextmanager @contextmanager
def prepare_forward( def prepare_forward(
self, self,
...@@ -946,6 +969,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -946,6 +969,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp) self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms) self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
assert self.fp8_meta["recipe"].reduce_amax, ( assert self.fp8_meta["recipe"].reduce_amax, (
...@@ -1346,6 +1370,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1346,6 +1370,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
self.name = f"Layer_{TEDebugState.get_layer_count()}" self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _check_weight_tensor_recipe_correspondence(self) -> None:
"""
Verify that the weight tensor types match their corresponding recipe type.
This is invoked in the forward().
This establishes a 1:1 correspondence between recipe types and tensor types:
- DelayedScaling → Float8Tensor
- Float8CurrentScaling → Float8Tensor
- MXFP8BlockScaling → MXFP8Tensor
- Float8BlockScaling → Float8BlockTensor
Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).
"""
if not self.fp8 and not self.fp8_calibration:
return
if not hasattr(self, "weight_names") or not self.weight_names:
return
recipe = self.fp8_meta["recipe"]
weight_tensors = [getattr(self, name) for name in self.weight_names]
for i, tensor in enumerate(weight_tensors):
if isinstance(tensor, QuantizedTensorBase):
quantizer = tensor._get_quantizer()
if quantizer is None:
continue
compatible_recipe_class = quantizer._get_compatible_recipe()
if compatible_recipe_class is None:
continue
if not isinstance(recipe, compatible_recipe_class):
raise RuntimeError(
f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
" Please check the recipes assigned during fp8_model_init() and"
" fp8_autocast() calls."
)
def _turn_off_unsupported_features_in_debug(self): def _turn_off_unsupported_features_in_debug(self):
if ( if (
getattr(self, "ub_bulk_wgrad", False) getattr(self, "ub_bulk_wgrad", False)
......
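
The checkpoint-compatibility shim above means a state dict saved with TE 1.x, whose fp8 extra state has no "recipe" entry and still carries scale_inv_fwd/scale_inv_bwd, loads into current modules without manual migration. A rough sketch with a hypothetical checkpoint path:

import torch
from transformer_engine.pytorch import Linear

# Hypothetical checkpoint produced with Transformer Engine 1.x.
state_dict = torch.load("te1x_linear_checkpoint.pt")

linear = Linear(32, 32).cuda()
# The missing recipe is filled in with DelayedScaling() (the only recipe TE 1.x
# supported) and the obsolete scale_inv_fwd/scale_inv_bwd entries are dropped.
linear.load_state_dict(state_dict)
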
@@ -4,13 +4,14 @@
 """Tensor class with FP8 data quantized with NxN tiles"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable
+from typing import Optional, Tuple, Iterable, Union
 import math

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import Float8BlockScaling, Recipe

 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..utils import devices_match, round_up_to_nearest_multiple
@@ -229,6 +230,9 @@ class Float8BlockQuantizer(Quantizer):
         # where state from an estimator influences distribution parameters.
         pass

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return Float8BlockScaling
+

 class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
     """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
......
@@ -4,13 +4,14 @@
 """Tensor class with FP8 data"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable
+from typing import Optional, Tuple, Iterable, Union
 import warnings

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe

 from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -166,6 +167,9 @@ class Float8Quantizer(Quantizer):
             quantizer=self,
         )

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return DelayedScaling
+

 class Float8CurrentScalingQuantizer(Quantizer):
     """Builder class for FP8 tensors with per-tensor current scaling
@@ -328,6 +332,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
         """Get process group for amax reduction"""
         return canonicalize_process_group(self.amax_reduction_group)

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return Float8CurrentScaling
+

 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
......
@@ -6,12 +6,13 @@
 from __future__ import annotations
 from collections.abc import Iterable
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe

 from ..constants import MXFP8_BLOCK_SCALING_SIZE
 from ..utils import devices_match, round_up_to_nearest_multiple
@@ -135,6 +136,9 @@ class MXFP8Quantizer(Quantizer):
         # TODO(ksivamani): No calibration needed for mxfp8?
         pass

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return MXFP8BlockScaling
+

 class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
......
@@ -13,6 +13,7 @@ import torch
 from torch.utils._pytree import tree_map

 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe


 class QuantizedTensorBase:
@@ -238,6 +239,10 @@ class Quantizer(abc.ABC):
         """Create shallow copy"""
         return copy.copy(self)

+    @abc.abstractmethod
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        """Returns recipe class that is compatible with this quantizer"""
+

 class _QuantizeFunc(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
......
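
Because _get_compatible_recipe() is now an abstract method on Quantizer, any out-of-tree quantizer has to declare which recipe it pairs with. A minimal sketch; the subclass name is hypothetical and the other abstract methods of Quantizer are omitted, so the class is illustrative rather than instantiable as-is:

from typing import Union

from transformer_engine.common.recipe import Float8CurrentScaling, Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer


class MyCurrentScalingQuantizer(Quantizer):
    """Hypothetical quantizer that builds tensors for per-tensor current scaling."""

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        # Returning a Recipe subclass opts this quantizer into the weight/recipe
        # check in forward(); returning None opts out (as DebugQuantizer does above).
        return Float8CurrentScaling
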