Unverified Commit 730fd115 authored by Evgeny Tsykunov, committed by GitHub

Enhance recipe compatibility (#1724)



* Check tensor-recipe compatibility
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Tensor class in recipe, checking for *Base
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Extend recipe __repr__ with recipe_type
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Warn about recipe change
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enable dynamic recipe change: clear fp8 workspace
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* TE 1.x checkpoint compatibility
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Disable warning for recipe wrappers
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Test recipe change
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use QuantizedTensorBase
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Fix circular import
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Revert previous circular import fix
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* Fix pytorch imports in common
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Let quantizer know about the recipe
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix imports
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

---------
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7be43390
......@@ -6,21 +6,31 @@ from typing import Iterable, Optional
import pytest
import torch
import warnings
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
_amax_and_scale_update,
get_default_fp8_recipe,
fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch import Linear
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
import transformer_engine_torch as tex
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
# FP8 per tensor delayed scaling
......@@ -367,3 +377,96 @@ class TestFP8Recipe:
)
torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
@pytest.mark.parametrize(
"model_init_recipe",
[
pytest.param(
MXFP8BlockScaling(),
marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
),
pytest.param(
Float8BlockScaling(),
marks=pytest.mark.skipif(
not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
),
),
],
)
def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
with fp8_model_init(enabled=True, recipe=model_init_recipe):
linear = Linear(32, 32).cuda()
x = torch.randn(32, 32, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
with pytest.raises(RuntimeError) as excinfo:
_ = linear(x)
assert "Recipe mismatch for " in str(excinfo.value)
@pytest.mark.parametrize(
"target_recipe_class, expected_quantizer_type, available_flag, reason",
[
pytest.param(
MXFP8BlockScaling,
MXFP8Quantizer,
mxfp8_available,
reason_for_no_mxfp8,
id="DelayedScaling->MXFP8BlockScaling",
),
pytest.param(
Float8BlockScaling,
Float8BlockQuantizer,
fp8_block_scaling_available,
reason_for_no_fp8_block_scaling,
id="DelayedScaling->Float8BlockScaling",
),
],
)
def test_dynamic_recipe_update(
self, target_recipe_class, expected_quantizer_type, available_flag, reason
):
if not available_flag:
pytest.skip(reason)
in_features = 32
out_features = 32
batch_size = 32
linear = Linear(in_features, out_features).cuda()
initial_recipe = DelayedScaling()
# Run initial iterations with DelayedScaling
for _ in range(3):
x = torch.randn(batch_size, in_features, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
y = linear(x)
loss = y.mean()
loss.backward()
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, Float8Quantizer)
# Change recipe
target_recipe = target_recipe_class()
# Run subsequent iterations with the target recipe
for i in range(3):
x = torch.randn(batch_size, in_features, device="cuda")
if i == 0:
# Expect a warning on the first iteration with the new recipe
with pytest.warns(UserWarning, match="Recipe type changed"):
with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
y = linear(x)
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, expected_quantizer_type)
else:
# No warning expected on subsequent iterations
with warnings.catch_warnings():
warnings.simplefilter("error") # Raise error if unexpected warning occurs
with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
y = linear(x)
loss = y.mean()
loss.backward()
# Final check
for quantizer in linear.quantizers["scaling_fwd"]:
assert isinstance(quantizer, expected_quantizer_type)
......@@ -87,7 +87,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
A.scaling_mode == B.scaling_mode ||
(A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
(A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
"Inputs A and B to GEMM need to have compatible scaling modes!");
"Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
GemmParam ret;
......
......@@ -180,6 +180,7 @@ class DelayedScaling(Recipe):
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
......@@ -245,6 +246,7 @@ class Float8CurrentScaling(Recipe):
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
......@@ -291,7 +293,11 @@ class MXFP8BlockScaling(Recipe):
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
def __repr__(self) -> str:
return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
)
@dataclass()
......@@ -375,6 +381,7 @@ class Float8BlockScaling(Recipe):
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
......
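For reference, the extended `__repr__` now leads with the concrete recipe class name; a quick sketch (the default field values shown in the comments are illustrative, not guaranteed):

```python
# Minimal sketch: each recipe's repr now starts with recipe_type.
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

print(repr(DelayedScaling()))     # e.g. "recipe_type=DelayedScaling, margin=0, format=HYBRID, ..."
print(repr(MXFP8BlockScaling()))  # e.g. "recipe_type=MXFP8BlockScaling, margin=0, format=E4M3"
```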
......@@ -14,7 +14,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
......@@ -459,6 +459,10 @@ class DebugQuantizer(Quantizer):
return True
return False
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Probably not needed for debug quantizer"""
return None
class DebugQuantizedTensor(QuantizedTensorBase):
"""
......
......@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
......@@ -811,6 +811,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None:
return
# TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
if "recipe" not in state:
# TE 1.x only supported delayed scaling, which was the default recipe
state["recipe"] = DelayedScaling()
# TE 1.x also saved scale_inv, which is not needed with Recipe object
state.pop("scale_inv_fwd", None)
state.pop("scale_inv_bwd", None)
# Load extra items
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"] = state["recipe"]
......@@ -884,6 +892,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
_original_recipe = self.fp8_meta.get("recipe", None)
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
......@@ -922,6 +932,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
@contextmanager
def prepare_forward(
self,
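From user code, the new warning and workspace reset behave as in the sketch below, mirroring the `test_dynamic_recipe_update` test above; layer sizes are arbitrary and MXFP8 support on the device is assumed:

```python
import warnings
import torch
from transformer_engine.pytorch import Linear, fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

linear = Linear(32, 32).cuda()
x = torch.randn(32, 32, device="cuda")

with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    linear(x)  # quantizers and FP8 workspaces are built for DelayedScaling

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
        linear(x)  # recipe type changed: warning emitted, cached workspaces cleared

assert any("Recipe type changed" in str(w.message) for w in caught)
```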
......@@ -946,6 +969,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
assert self.fp8_meta["recipe"].reduce_amax, (
......@@ -1346,6 +1370,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _check_weight_tensor_recipe_correspondence(self) -> None:
"""
Verify that the weight tensor types match their corresponding recipe type.
This is invoked in the forward().
This establishes a 1:1 correspondence between recipe types and tensor types:
- DelayedScaling → Float8Tensor
- Float8CurrentScaling → Float8Tensor
- MXFP8BlockScaling → MXFP8Tensor
- Float8BlockScaling → Float8BlockwiseQTensor
Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).
"""
if not self.fp8 and not self.fp8_calibration:
return
if not hasattr(self, "weight_names") or not self.weight_names:
return
recipe = self.fp8_meta["recipe"]
weight_tensors = [getattr(self, name) for name in self.weight_names]
for i, tensor in enumerate(weight_tensors):
if isinstance(tensor, QuantizedTensorBase):
quantizer = tensor._get_quantizer()
if quantizer is None:
continue
compatible_recipe_class = quantizer._get_compatible_recipe()
if compatible_recipe_class is None:
continue
if not isinstance(recipe, compatible_recipe_class):
raise RuntimeError(
f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
" Please check the recipes assigned during fp8_model_init() and"
" fp8_autocast() calls."
)
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
......
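The correspondence check raises on the first forward pass; a minimal reproduction, mirroring the new `test_check_for_weight_tensor_and_recipe_correspondence` test and assuming MXFP8-capable hardware:

```python
import torch
from transformer_engine.pytorch import Linear, fp8_autocast, fp8_model_init
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

# Weights are created as MXFP8 tensors ...
with fp8_model_init(enabled=True, recipe=MXFP8BlockScaling()):
    linear = Linear(32, 32).cuda()

x = torch.randn(32, 32, device="cuda")

# ... but the autocast recipe is DelayedScaling, so the check fails.
try:
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        linear(x)
except RuntimeError as err:
    assert "Recipe mismatch for" in str(err)
```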
......@@ -4,13 +4,14 @@
"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
from typing import Optional, Tuple, Iterable, Union
import math
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
......@@ -229,6 +230,9 @@ class Float8BlockQuantizer(Quantizer):
# where state from an estimator influences distribution parameters.
pass
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8BlockScaling
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
......
......@@ -4,13 +4,14 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
from typing import Optional, Tuple, Iterable, Union
import warnings
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
......@@ -166,6 +167,9 @@ class Float8Quantizer(Quantizer):
quantizer=self,
)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return DelayedScaling
class Float8CurrentScalingQuantizer(Quantizer):
"""Builder class for FP8 tensors with per-tensor current scaling
......@@ -328,6 +332,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8CurrentScaling
class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
......
......@@ -6,12 +6,13 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
......@@ -135,6 +136,9 @@ class MXFP8Quantizer(Quantizer):
# TODO(ksivamani): No calibration needed for mxfp8?
pass
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return MXFP8BlockScaling
class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
......
......@@ -13,6 +13,7 @@ import torch
from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
class QuantizedTensorBase:
......@@ -238,6 +239,10 @@ class Quantizer(abc.ABC):
"""Create shallow copy"""
return copy.copy(self)
@abc.abstractmethod
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer"""
class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
......
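The new hook gives each quantizer a single source of truth for its matching recipe class; a small sketch of how it can be queried (the MXFP8Quantizer constructor arguments are an assumption for illustration):

```python
import transformer_engine_torch as tex
from transformer_engine.common.recipe import MXFP8BlockScaling
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)  # assumed signature
assert quantizer._get_compatible_recipe() is MXFP8BlockScaling
```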