Unverified Commit c9508000 authored by jberchtold-nvidia, committed by GitHub

[JAX] Decouple Recipe and ScalingMode (#1728)



* Decouple recipe and scaling mode
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
parent 04add79d
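In short: the static QuantizeConfig class state is replaced by per-recipe config instances reached through a get_quantize_config() getter, and scaling modes are now queried per tensor source rather than from a single global SCALING_MODE. A minimal sketch of the new API, assuming an FP8-capable device and single-device defaults (all names as introduced in the diff below):

from transformer_engine.common import recipe
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import ScalingMode, TensorSource, get_quantize_config
from transformer_engine.jax.sharding import MeshResource

with fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling(), mesh_resource=MeshResource()):
    # Before: QuantizeConfig.SCALING_MODE held one mode for all tensors.
    # After: each tensor source reports its own scaling mode.
    assert get_quantize_config().is_fp8_enabled()
    assert (
        get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
        == ScalingMode.DELAYED_TENSOR_SCALING
    )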
@@ -14,10 +14,11 @@ from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling,
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import (
QuantizeConfig,
get_quantize_config,
is_fp8_available,
ScalingMode,
update_collections,
TensorSource,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
@@ -49,7 +50,7 @@ class TestHelper(unittest.TestCase):
class TestFP8Functions(unittest.TestCase):
def _check_default_state(self):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin)
@@ -58,17 +59,23 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source),
ScalingMode.CURRENT_TENSOR_SCALING,
)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
@@ -78,21 +85,20 @@ class TestFP8Functions(unittest.TestCase):
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
with fp8_autocast(
@@ -104,21 +110,20 @@ class TestFP8Functions(unittest.TestCase):
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_block_scaling(self):
QuantizeConfig.finalize() # Ensure the test is not affected by previous tests.
self._check_default_state()
with fp8_autocast(
@@ -130,14 +135,14 @@ class TestFP8Functions(unittest.TestCase):
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
@@ -23,12 +23,14 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
QuantizeConfig,
get_quantize_config,
ScalingMode,
is_fp8_available,
update_collections,
TensorSource,
fp8_autocast,
)
from transformer_engine.jax.sharding import MeshResource, global_shard_guard
from transformer_engine.jax.sharding import MeshResource
@pytest.fixture(autouse=True, scope="function")
@@ -356,7 +358,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params)
if QuantizeConfig.is_fp8_enabled():
if get_quantize_config().is_fp8_enabled():
for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs,
@@ -365,12 +367,15 @@ class BaseRunner:
test_others,
test_layer,
)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
if (
get_quantize_config().get_scaling_mode(TensorSource.X)
== ScalingMode.DELAYED_TENSOR_SCALING
):
_, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME
updated_state[0], get_quantize_config().COLLECTION_NAME
)
test_others = update_collections(
{QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
{get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
)
del updated_quantize_meta
del updated_state
@@ -500,41 +505,33 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
QuantizeConfig.finalize() # Ensure FP8 disabled.
with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
QuantizeConfig.finalize() # Ensure FP8 disabled.
with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
# Ensure FP8 disabled.
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
with global_shard_guard(
MeshResource()
): # Empty MeshResource is used as we are running on a single device
# Empty MeshResource is used as we are running on a single device
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
class TestEncoderLayer(BaseTester):
......
@@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
"""
Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the QuantizeConfig.initialize() methods.
This helper is used in the BaseQuantizeConfig.initialize_from_recipe() methods.
Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
@@ -28,7 +28,7 @@ from ..quantize import (
ScalingMode,
Quantizer,
GroupedQuantizer,
QuantizeConfig,
get_quantize_config,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
@@ -754,7 +754,7 @@ def _te_gemm(
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands
@@ -1107,7 +1107,7 @@ def _jax_gemm(
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
precision = (
jax.lax.Precision.HIGHEST
if QuantizeConfig.FP8_2X_ACC_FPROP
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
@@ -32,7 +32,14 @@ from ..cpp_extensions import (
jax_scaled_masked_softmax,
jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..quantize import (
QuantizerFactory,
get_quantize_config,
QuantizeMeta,
QuantizeMetaSet,
ScalingMode,
TensorSource,
)
PRNGKey = Any
Shape = Tuple[int, ...]
@@ -350,7 +357,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
collection_name = (
variable_collection
if variable_collection is not None
else QuantizeConfig.COLLECTION_NAME
else get_quantize_config().COLLECTION_NAME
)
scale = self.variable(
collection_name,
@@ -363,14 +370,14 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(QuantizeConfig.AMAX_HISTORY_LEN,),
(get_quantize_config().AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
fp8_recipe, recipe.DelayedScaling
):
if get_quantize_config().get_scaling_mode(
TensorSource.X
) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
@@ -483,7 +490,7 @@ class DenseGeneral(TransformerEngineBase):
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel = kernel.astype(input_dtype)
if self.use_bias:
@@ -692,7 +699,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
quantizer_set = self.generate_quantizer_set()
fuse_layernorm = (
QuantizeConfig.is_fp8_enabled()
get_quantize_config().is_fp8_enabled()
and not self.return_layernorm_output
and self.enable_layernorm
)
@@ -743,7 +750,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape,
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
@@ -1005,7 +1012,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm = (
QuantizeConfig.is_fp8_enabled()
get_quantize_config().is_fp8_enabled()
and not self.return_layernorm_output
and self.enable_layernorm
)
@@ -1088,7 +1095,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
@@ -1100,7 +1107,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape,
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
if not get_quantize_config().is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......
@@ -7,9 +7,11 @@ Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple, Dict, Union, Sequence
from typing import Optional, Tuple, Dict, Union, Sequence, Type
from functools import reduce
import operator
@@ -26,7 +28,7 @@ from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [
"QuantizeConfig",
"get_quantize_config",
"fp8_autocast",
"is_fp8_available",
"update_collections",
@@ -34,12 +36,15 @@ __all__ = [
"apply_padding_to_scale_inv",
"remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME",
"TensorSource",
]
_is_fp8_available = None
_reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict]
NVTE_FP8_COLLECTION_NAME = "fp8_metas"
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.
@@ -154,6 +159,17 @@ def _format2dtypes(format_: recipe.Format):
return jnp.bfloat16, jnp.bfloat16
class TensorSource(Enum):
"""Enumeration for where a tensor's data comes from."""
# Input data
X = 0
# Model parameters
KERNEL = 1
# Gradients in the backward pass
DGRAD = 2
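A small illustration of how the enum is consumed (a sketch only; it assumes an active fp8_autocast context so that get_quantize_config(), defined later in this module, returns a recipe-backed config):

# Each tensor source can, in principle, carry its own scaling mode:
for source in TensorSource:
    print(source.name, get_quantize_config().get_scaling_mode(source))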
class AmaxComputeAlgo(Enum):
"""Enumeration for AMAX computation algorithms.
@@ -166,28 +182,8 @@ class AmaxComputeAlgo(Enum):
MOST_RECENT = "most_recent"
def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
"""Convert recipe.Recipe to ScalingMode.
Args:
fp8_recipe: The FP8 recipe to convert
Returns:
The corresponding ScalingMode
Raises:
ValueError: If the recipe type is not supported
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.MXFP8_1D_SCALING
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return ScalingMode.CURRENT_TENSOR_SCALING
raise ValueError("Invalid fp8_recipe!")
class QuantizeConfig:
@dataclass
class BaseQuantizeConfig(ABC):
"""Configuration class for quantization settings.
This class manages global quantization settings including FP8 formats,
@@ -204,14 +200,13 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
"""
INITIALIZED = False
MARGIN: float = 0.0
COLLECTION_NAME: str = "fp8_metas"
COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
@@ -219,61 +214,82 @@
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@staticmethod
def is_fp8_enabled():
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
self.INITIALIZED = True
self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
self.FP8_FORMAT = fp8_recipe.fp8_format
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT)
def is_fp8_enabled(self) -> bool:
"""Check if FP8 quantization is enabled.
Returns:
bool: True if quantization is enabled, False otherwise
"""
return QuantizeConfig.INITIALIZED
return self.INITIALIZED
@classmethod
def initialize(cls, fp8_recipe: recipe.Recipe) -> None:
"""Initialize the quantization configuration.
@abstractmethod
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type.
Args:
fp8_recipe: The FP8 recipe to use for initialization
tensor_source: The usage type for which to get the scaling mode.
Returns:
The scaling mode for the specified usage type.
"""
def is_supported(self) -> tuple[bool, str]:
"""Check if this QuantizeConfig class is supported on the available devices.
Returns:
bool: True if the class is supported, False otherwise
str: Reason for being unsupported, if applicable.
"""
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
@classmethod
def finalize(cls) -> None:
"""Reset the quantization configuration to default values."""
cls.INITIALIZED = False
cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.INFERENCE_MODE = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
class DelayedScalingQuantizeConfig:
x_scaling_mode = self.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD)
for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]:
is_supported, reason = is_fp8_available(scaling_mode=scaling_mode)
if not is_supported:
return is_supported, reason
return True, None
class NoOpQuantizeConfig(BaseQuantizeConfig):
"""Configuration class higher-precision non-quantized operation."""
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize no-op configuration."""
raise NotImplementedError(
"NoOpQuantizeConfig cannot be initialize from a recipe as it represents"
" higher-precision when no quantized recipe is set."
)
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.NO_SCALING
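NoOpQuantizeConfig also doubles as the simplest template for new subclasses. As a purely hypothetical sketch (not part of this PR) of the kind of mixed-mode recipe this decoupling is meant to enable:

class HypotheticalMixedQuantizeConfig(BaseQuantizeConfig):
    """Hypothetical: MXFP8-scaled kernels, current-scaled activations and grads."""

    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        if tensor_source == TensorSource.KERNEL:
            return ScalingMode.MXFP8_1D_SCALING
        return ScalingMode.CURRENT_TENSOR_SCALING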
class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for delayed scaling FP8 recipe.
This class provides specific initialization and finalization for delayed scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize delayed scaling FP8 configuration.
Args:
@@ -282,6 +298,8 @@ class DelayedScalingQuantizeConfig:
Raises:
AssertionError: If recipe parameters are not supported
"""
super().initialize_from_recipe(fp8_recipe)
assert fp8_recipe.amax_compute_algo in [
"max",
"most_recent",
@@ -291,71 +309,88 @@
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
string_to_amax_compute_algo = {
"max": AmaxComputeAlgo.MAX,
"most_recent": AmaxComputeAlgo.MOST_RECENT,
}
cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
cls.FP8_2X_ACC_DGRAD = True
cls.FP8_2X_ACC_WGRAD = True
self.FP8_2X_ACC_DGRAD = True
self.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
"""Reset the delayed scaling configuration."""
QuantizeConfig.finalize()
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.DELAYED_TENSOR_SCALING
class CurrentScalingQuantizeConfig:
class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for current scaling FP8 recipe.
This class provides specific initialization and finalization for current scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize current scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
super().initialize_from_recipe(fp8_recipe)
self.AMAX_HISTORY_LEN = 0
@staticmethod
def finalize() -> None:
"""Reset the current scaling configuration."""
QuantizeConfig.finalize()
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.CURRENT_TENSOR_SCALING
class BlockScalingQuantizeConfig:
class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for block scaling FP8 recipe.
This class provides specific initialization and finalization for block scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
"""Initialize block scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
super().initialize_from_recipe(fp8_recipe)
self.AMAX_HISTORY_LEN = 0
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.MXFP8_1D_SCALING
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
@staticmethod
def finalize() -> None:
"""Reset the block scaling configuration."""
QuantizeConfig.finalize()
def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by fp8_autocast context."""
return _QUANTIZE_CONFIG
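Because the default instance is a NoOpQuantizeConfig, callers may query the config unconditionally; outside any fp8_autocast context the answers degrade gracefully:

# Sketch of the default (no autocast) behavior:
assert not get_quantize_config().is_fp8_enabled()
assert get_quantize_config().get_scaling_mode(TensorSource.X) == ScalingMode.NO_SCALING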
def get_quantize_config_class(
fp8_recipe: recipe.Recipe,
) -> Type[BaseQuantizeConfig]:
"""Get the quantization configuration based on the FP8 recipe.
Args:
fp8_recipe: The FP8 recipe to use for initialization
Returns:
The quantization config class corresponding to the given recipe.
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return CurrentScalingQuantizeConfig
raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")
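For example, the mapping resolves as follows (recipe classes from transformer_engine.common.recipe):

assert get_quantize_config_class(recipe.DelayedScaling()) is DelayedScalingQuantizeConfig
assert get_quantize_config_class(recipe.MXFP8BlockScaling()) is BlockScalingQuantizeConfig
assert get_quantize_config_class(recipe.Float8CurrentScaling()) is CurrentScalingQuantizeConfig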
@contextmanager
@@ -404,22 +439,22 @@ def fp8_autocast(
if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling()
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
Config = CurrentScalingQuantizeConfig
global _QUANTIZE_CONFIG
old_quantize_config = _QUANTIZE_CONFIG
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
try:
with global_shard_guard(mesh_resource):
if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe))
assert fp8_available, reason_for_no_fp8
Config.initialize(fp8_recipe)
_QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)()
is_supported, reason = _QUANTIZE_CONFIG.is_supported()
assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe)
yield
finally:
Config.finalize()
_QUANTIZE_CONFIG = old_quantize_config
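Saving and restoring the previous instance (rather than resetting class-level attributes) makes nested contexts well-defined; a sketch, assuming an FP8-capable device:

with fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling(), mesh_resource=MeshResource()):
    assert get_quantize_config().is_fp8_enabled()
    with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
        assert not get_quantize_config().is_fp8_enabled()
    # On exit from the inner context, the outer config instance is restored.
    assert get_quantize_config().is_fp8_enabled()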
def get_delayed_scaling():
@@ -437,12 +472,12 @@ def get_delayed_scaling():
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = (
"max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
"max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return recipe.DelayedScaling(
margin=int(QuantizeConfig.MARGIN),
fp8_format=QuantizeConfig.FP8_FORMAT,
amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN,
margin=int(get_quantize_config().MARGIN),
fp8_format=get_quantize_config().FP8_FORMAT,
amax_history_len=get_quantize_config().AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
)
@@ -581,6 +616,3 @@ def apply_padding_to_scale_inv(
# Pad the scales with the lowest representable value (2^-127) and return
pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
@@ -21,9 +21,10 @@ from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
QuantizeConfig,
get_quantize_config,
get_quantize_config_class,
AmaxComputeAlgo,
_get_scaling_mode,
TensorSource,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
@@ -56,7 +57,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None:
scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
return sf
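For intuition, the same scale formula with concrete numbers (a standalone sketch using jax.numpy; MARGIN defaults to 0.0):

import jax.numpy as jnp

fp8_max = jnp.float32(jnp.finfo(jnp.float8_e4m3fn).max)  # 448.0 for E4M3
amax = jnp.float32(2.0)                                  # observed max |x|
margin = 0.0                                             # get_quantize_config().MARGIN default
sf = (fp8_max / amax) / (2**margin)                      # 224.0, so x * sf fills the E4M3 range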
@@ -234,7 +235,7 @@ class CurrentScaleQuantizer(Quantizer):
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = jnp.max(jnp.abs(x)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scaled_x = x.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
@@ -320,7 +321,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
@@ -397,7 +398,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value
"""
# 2. Calculate the current scale
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
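The two branches in plain numbers (hypothetical history; index 0 holds the most recently recorded amax):

import jax.numpy as jnp

amax_history = jnp.array([3.0, 1.0, 7.0])
assert jnp.max(amax_history, axis=-1, keepdims=True) == 7.0  # AmaxComputeAlgo.MAX
assert amax_history[0:1] == 3.0                              # AmaxComputeAlgo.MOST_RECENT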
@@ -827,12 +828,21 @@ class QuantizerFactory:
@staticmethod
def _create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
x_scaling_mode,
kernel_scaling_mode,
grad_scaling_mode,
fwd_dtype,
bwd_dtype,
is_2x2x,
n_groups,
**kwargs,
) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes.
Args:
scaling_mode: Scaling mode to use
x_scaling_mode: Scaling mode to use for input tensor 'x'
kernel_scaling_mode: Scaling mode to use for kernel tensor
grad_scaling_mode: Scaling mode to use for gradient tensor
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
@@ -846,9 +856,9 @@
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if scaling_mode.is_1d_block_scaling():
if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE:
if get_quantize_config().INFERENCE_MODE:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
@@ -868,12 +878,12 @@
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
)
q_dgrad = QuantizerFactory.create(
1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
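A hedged sketch of the new per-source plumbing (_create_set is an internal helper, and the explicit mixed assignment below is hypothetical, since the recipes in this PR still return one mode for all sources):

import jax.numpy as jnp

q_set = QuantizerFactory._create_set(
    x_scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
    kernel_scaling_mode=ScalingMode.MXFP8_1D_SCALING,  # hypothetical mixed mode
    grad_scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
    fwd_dtype=jnp.float8_e4m3fn,
    bwd_dtype=jnp.float8_e5m2,
    is_2x2x=False,
    n_groups=None,
)
# q_set.x, q_set.kernel, and q_set.dgrad each hold a quantizer built
# with its own scaling mode.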
@@ -892,10 +902,10 @@
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization; by default this is inferred from the scaling mode and device support
n_groups:
fp8_recipe: Recipe to use for quantization. The scaling mode can be given directly via the scaling_mode parameter or derived from the recipe; the recipe is preferred, as it will support future recipes whose scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization
@@ -912,27 +922,44 @@
)
if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
scaling_mode = _get_scaling_mode(fp8_recipe)
quantize_config = get_quantize_config_class(fp8_recipe)()
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
elif scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
else:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None:
if scaling_mode.is_1d_block_scaling():
# TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling():
is_2x2x = True
elif scaling_mode.is_tensor_scaling():
elif x_scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE
is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = []
for _ in range(n_quantizer_sets):
q_set.append(
QuantizerFactory._create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
x_scaling_mode=x_scaling_mode,
kernel_scaling_mode=kernel_scaling_mode,
grad_scaling_mode=grad_scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x,
n_groups=n_groups,
**kwargs,
)
)
......
@@ -396,7 +396,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
# if get_quantize_config().INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE
......