"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "81429b80388b3e0b3b7a746ed3568694a2fdd5eb"
Unverified Commit c9508000 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Decouple Recipe and ScalingMode (#1728)



* Decouple recipe and scaling mode
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
parent 04add79d
...@@ -14,10 +14,11 @@ from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, ...@@ -14,10 +14,11 @@ from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling,
from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
QuantizeConfig, get_quantize_config,
is_fp8_available, is_fp8_available,
ScalingMode, ScalingMode,
update_collections, update_collections,
TensorSource,
) )
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
...@@ -49,7 +50,7 @@ class TestHelper(unittest.TestCase): ...@@ -49,7 +50,7 @@ class TestHelper(unittest.TestCase):
class TestFP8Functions(unittest.TestCase): class TestFP8Functions(unittest.TestCase):
def _check_default_state(self): def _check_default_state(self):
self.assertFalse(QuantizeConfig.is_fp8_enabled()) self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, ref, test): def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin) self.assertTrue(ref.margin == test.margin)
...@@ -58,17 +59,23 @@ class TestFP8Functions(unittest.TestCase): ...@@ -58,17 +59,23 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test): def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source),
ScalingMode.CURRENT_TENSOR_SCALING,
)
def _compare_mxfp8_scaling(self, test): def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin) self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING) for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
)
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self): def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
...@@ -78,21 +85,20 @@ class TestFP8Functions(unittest.TestCase): ...@@ -78,21 +85,20 @@ class TestFP8Functions(unittest.TestCase):
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_default_state() self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self): def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
with fp8_autocast( with fp8_autocast(
...@@ -104,21 +110,20 @@ class TestFP8Functions(unittest.TestCase): ...@@ -104,21 +110,20 @@ class TestFP8Functions(unittest.TestCase):
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs) self._compare_current_scaling(cs)
self._check_default_state() self._check_default_state()
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs) self._compare_current_scaling(cs)
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_block_scaling(self): def test_fp8_autocast_mxfp8_block_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_default_state() self._check_default_state()
with fp8_autocast( with fp8_autocast(
...@@ -130,14 +135,14 @@ class TestFP8Functions(unittest.TestCase): ...@@ -130,14 +135,14 @@ class TestFP8Functions(unittest.TestCase):
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs) self._compare_mxfp8_scaling(bs)
self._check_default_state() self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(QuantizeConfig.is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs) self._compare_mxfp8_scaling(bs)
self._check_default_state() self._check_default_state()
...@@ -23,12 +23,14 @@ from utils import EncoderLayer as RefEncoderLayer ...@@ -23,12 +23,14 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
QuantizeConfig, get_quantize_config,
ScalingMode, ScalingMode,
is_fp8_available, is_fp8_available,
update_collections, update_collections,
TensorSource,
fp8_autocast,
) )
from transformer_engine.jax.sharding import MeshResource, global_shard_guard from transformer_engine.jax.sharding import MeshResource
@pytest.fixture(autouse=True, scope="function") @pytest.fixture(autouse=True, scope="function")
...@@ -356,7 +358,7 @@ class BaseRunner: ...@@ -356,7 +358,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params) ref_params, test_params = self._sync_params(ref_params, test_params)
if QuantizeConfig.is_fp8_enabled(): if get_quantize_config().is_fp8_enabled():
for _ in range(4): for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs, inputs,
...@@ -365,12 +367,15 @@ class BaseRunner: ...@@ -365,12 +367,15 @@ class BaseRunner:
test_others, test_others,
test_layer, test_layer,
) )
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: if (
get_quantize_config().get_scaling_mode(TensorSource.X)
== ScalingMode.DELAYED_TENSOR_SCALING
):
_, updated_quantize_meta = flax.core.pop( _, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME updated_state[0], get_quantize_config().COLLECTION_NAME
) )
test_others = update_collections( test_others = update_collections(
{QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
) )
del updated_quantize_meta del updated_quantize_meta
del updated_state del updated_state
...@@ -500,41 +505,33 @@ class BaseTester: ...@@ -500,41 +505,33 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs): def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward""" """Test normal datatype forward"""
QuantizeConfig.finalize() # Ensure FP8 disabled. # Ensure FP8 disabled.
with global_shard_guard( # Empty MeshResource is used as we are running on a single device
MeshResource() with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_forward(data_shape, dtype) self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs): def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward""" """Test normal datatype backward"""
QuantizeConfig.finalize() # Ensure FP8 disabled. # Ensure FP8 disabled.
with global_shard_guard( # Empty MeshResource is used as we are running on a single device
MeshResource() with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_backward(data_shape, dtype) self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled""" """Test forward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe) # Empty MeshResource is used as we are running on a single device
with global_shard_guard( with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
MeshResource()
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled""" """Test backward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe) # Empty MeshResource is used as we are running on a single device
with global_shard_guard( with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
MeshResource()
): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
class TestEncoderLayer(BaseTester): class TestEncoderLayer(BaseTester):
......
...@@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F ...@@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
""" """
Helper function to manage primitive states by name without modifying environment variables. Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the QuantizeConfig.initialize() methods. This helper is used in the get_quantize_config().initialize() methods.
Args: Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
...@@ -28,7 +28,7 @@ from ..quantize import ( ...@@ -28,7 +28,7 @@ from ..quantize import (
ScalingMode, ScalingMode,
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizeConfig, get_quantize_config,
QuantizerSet, QuantizerSet,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
...@@ -754,7 +754,7 @@ def _te_gemm( ...@@ -754,7 +754,7 @@ def _te_gemm(
fuse_bias: bool = False, fuse_bias: bool = False,
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
...@@ -1107,7 +1107,7 @@ def _jax_gemm( ...@@ -1107,7 +1107,7 @@ def _jax_gemm(
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
precision = ( precision = (
jax.lax.Precision.HIGHEST jax.lax.Precision.HIGHEST
if QuantizeConfig.FP8_2X_ACC_FPROP if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT else jax.lax.Precision.DEFAULT
) )
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
...@@ -32,7 +32,14 @@ from ..cpp_extensions import ( ...@@ -32,7 +32,14 @@ from ..cpp_extensions import (
jax_scaled_masked_softmax, jax_scaled_masked_softmax,
jax_scaled_upper_triang_masked_softmax, jax_scaled_upper_triang_masked_softmax,
) )
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..quantize import (
QuantizerFactory,
get_quantize_config,
QuantizeMeta,
QuantizeMetaSet,
ScalingMode,
TensorSource,
)
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -350,7 +357,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -350,7 +357,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
collection_name = ( collection_name = (
variable_collection variable_collection
if variable_collection is not None if variable_collection is not None
else QuantizeConfig.COLLECTION_NAME else get_quantize_config().COLLECTION_NAME
) )
scale = self.variable( scale = self.variable(
collection_name, collection_name,
...@@ -363,14 +370,14 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -363,14 +370,14 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
collection_name, collection_name,
f"{quantizer_name}{postfix}_amax_history", f"{quantizer_name}{postfix}_amax_history",
jnp.zeros, jnp.zeros,
(QuantizeConfig.AMAX_HISTORY_LEN,), (get_quantize_config().AMAX_HISTORY_LEN,),
jnp.float32, jnp.float32,
).value ).value
return QuantizeMeta(scale=scale, amax_history=amax_history) return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance( if get_quantize_config().get_scaling_mode(
fp8_recipe, recipe.DelayedScaling TensorSource.X
): ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
x_meta = generate_quantize_meta("x") x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel") kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad") grad_meta = generate_quantize_meta("grad")
...@@ -483,7 +490,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -483,7 +490,7 @@ class DenseGeneral(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not QuantizeConfig.is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
if self.use_bias: if self.use_bias:
...@@ -692,7 +699,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -692,7 +699,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
quantizer_set = self.generate_quantizer_set() quantizer_set = self.generate_quantizer_set()
fuse_layernorm = ( fuse_layernorm = (
QuantizeConfig.is_fp8_enabled() get_quantize_config().is_fp8_enabled()
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -743,7 +750,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -743,7 +750,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape, kernel_shape,
self.dtype, self.dtype,
) )
if not QuantizeConfig.is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -1005,7 +1012,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1005,7 +1012,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision # TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented # when NoOpQuantizer and Tensor are implemented
fuse_layernorm = ( fuse_layernorm = (
QuantizeConfig.is_fp8_enabled() get_quantize_config().is_fp8_enabled()
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -1088,7 +1095,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1088,7 +1095,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not QuantizeConfig.is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
...@@ -1100,7 +1107,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1100,7 +1107,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape, kernel_2_shape,
self.dtype, self.dtype,
) )
if not QuantizeConfig.is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
......
...@@ -7,9 +7,11 @@ Config module for quantization metadata management ...@@ -7,9 +7,11 @@ Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes. in JAX, including support for different scaling modes and datatypes.
""" """
from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional, Tuple, Dict, Union, Sequence from typing import Optional, Tuple, Dict, Union, Sequence, Type
from functools import reduce from functools import reduce
import operator import operator
...@@ -26,7 +28,7 @@ from .. import cpp_extensions as tex ...@@ -26,7 +28,7 @@ from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability from .device_utils import get_device_compute_capability
__all__ = [ __all__ = [
"QuantizeConfig", "get_quantize_config",
"fp8_autocast", "fp8_autocast",
"is_fp8_available", "is_fp8_available",
"update_collections", "update_collections",
...@@ -34,12 +36,15 @@ __all__ = [ ...@@ -34,12 +36,15 @@ __all__ = [
"apply_padding_to_scale_inv", "apply_padding_to_scale_inv",
"remove_padding_from_scale_inv", "remove_padding_from_scale_inv",
"NVTE_FP8_COLLECTION_NAME", "NVTE_FP8_COLLECTION_NAME",
"TensorSource",
] ]
_is_fp8_available = None _is_fp8_available = None
_reason_for_no_fp8 = "" _reason_for_no_fp8 = ""
Collection = Union[Dict, FrozenDict] Collection = Union[Dict, FrozenDict]
NVTE_FP8_COLLECTION_NAME = "fp8_metas"
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture. """Check if delayed scaling FP8 is supported on the given GPU architecture.
...@@ -154,6 +159,17 @@ def _format2dtypes(format_: recipe.Format): ...@@ -154,6 +159,17 @@ def _format2dtypes(format_: recipe.Format):
return jnp.bfloat16, jnp.bfloat16 return jnp.bfloat16, jnp.bfloat16
class TensorSource(Enum):
"""Enumeration for where a tensor's data comes from."""
# Input data
X = 0
# Model parameters
KERNEL = 1
# Gradients in the backward pass
DGRAD = 2
class AmaxComputeAlgo(Enum): class AmaxComputeAlgo(Enum):
"""Enumeration for AMAX computation algorithms. """Enumeration for AMAX computation algorithms.
...@@ -166,28 +182,8 @@ class AmaxComputeAlgo(Enum): ...@@ -166,28 +182,8 @@ class AmaxComputeAlgo(Enum):
MOST_RECENT = "most_recent" MOST_RECENT = "most_recent"
def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: @dataclass
"""Convert recipe.Recipe to ScalingMode. class BaseQuantizeConfig(ABC):
Args:
fp8_recipe: The FP8 recipe to convert
Returns:
The corresponding ScalingMode
Raises:
ValueError: If the recipe type is not supported
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.MXFP8_1D_SCALING
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return ScalingMode.CURRENT_TENSOR_SCALING
raise ValueError("Invalid fp8_recipe!")
class QuantizeConfig:
"""Configuration class for quantization settings. """Configuration class for quantization settings.
This class manages global quantization settings including FP8 formats, This class manages global quantization settings including FP8 formats,
...@@ -204,14 +200,13 @@ class QuantizeConfig: ...@@ -204,14 +200,13 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
INFERENCE_MODE: Whether to enable optimization for inference INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
""" """
INITIALIZED = False INITIALIZED = False
MARGIN: float = 0.0 MARGIN: float = 0.0
COLLECTION_NAME: str = "fp8_metas" COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
...@@ -219,61 +214,82 @@ class QuantizeConfig: ...@@ -219,61 +214,82 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False
INFERENCE_MODE: bool = False INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling # DelayedScaling
AMAX_HISTORY_LEN: int = 1024 AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
@staticmethod def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
def is_fp8_enabled(): """Initialize the quantization configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
self.INITIALIZED = True
self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
self.FP8_FORMAT = fp8_recipe.fp8_format
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT)
def is_fp8_enabled(self) -> bool:
"""Check if FP8 quantization is enabled. """Check if FP8 quantization is enabled.
Returns: Returns:
bool: True if quantization is enabled, False otherwise bool: True if quantization is enabled, False otherwise
""" """
return QuantizeConfig.INITIALIZED return self.INITIALIZED
@classmethod @abstractmethod
def initialize(cls, fp8_recipe: recipe.Recipe) -> None: def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Initialize the quantization configuration. """Gets the scaling mode for a specific tensor's usage type.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization tensor_source: The usage type for which to get the scaling mode.
Returns:
The scaling mode for the specified usage type.
"""
def is_supported(self) -> tuple[bool, str]:
"""Check if this QuantizeConfig class is supported on the available devices.
Returns:
bool: True if the class is supported, False otherwise
str: Reason for being unsupported, if applicable.
""" """
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 x_scaling_mode = self.get_scaling_mode(TensorSource.X)
cls.FP8_FORMAT = fp8_recipe.fp8_format kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL)
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]:
is_supported, reason = is_fp8_available(scaling_mode=scaling_mode)
@classmethod if not is_supported:
def finalize(cls) -> None: return is_supported, reason
"""Reset the quantization configuration to default values.""" return True, None
cls.INITIALIZED = False
cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID class NoOpQuantizeConfig(BaseQuantizeConfig):
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) """Configuration class higher-precision non-quantized operation."""
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.FP8_2X_ACC_FPROP = False def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
cls.FP8_2X_ACC_DGRAD = False """Initialize no-op configuration."""
cls.FP8_2X_ACC_WGRAD = False raise NotImplementedError(
cls.SCALING_MODE = ScalingMode.NO_SCALING "NoOpQuantizeConfig cannot be initialize from a recipe as it represents"
cls.INFERENCE_MODE = False " higher-precision when no quantized recipe is set."
# DelayedScaling )
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.NO_SCALING
class DelayedScalingQuantizeConfig:
class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for delayed scaling FP8 recipe. """Configuration class for delayed scaling FP8 recipe.
This class provides specific initialization and finalization for delayed scaling This class provides specific initialization and finalization for delayed scaling
FP8 quantization mode. FP8 quantization mode.
""" """
@staticmethod def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize delayed scaling FP8 configuration. """Initialize delayed scaling FP8 configuration.
Args: Args:
...@@ -282,6 +298,8 @@ class DelayedScalingQuantizeConfig: ...@@ -282,6 +298,8 @@ class DelayedScalingQuantizeConfig:
Raises: Raises:
AssertionError: If recipe parameters are not supported AssertionError: If recipe parameters are not supported
""" """
super().initialize_from_recipe(fp8_recipe)
assert fp8_recipe.amax_compute_algo in [ assert fp8_recipe.amax_compute_algo in [
"max", "max",
"most_recent", "most_recent",
...@@ -291,71 +309,88 @@ class DelayedScalingQuantizeConfig: ...@@ -291,71 +309,88 @@ class DelayedScalingQuantizeConfig:
), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."
cls = QuantizeConfig self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
string_to_amax_compute_algo = { string_to_amax_compute_algo = {
"max": AmaxComputeAlgo.MAX, "max": AmaxComputeAlgo.MAX,
"most_recent": AmaxComputeAlgo.MOST_RECENT, "most_recent": AmaxComputeAlgo.MOST_RECENT,
} }
cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
cls.FP8_2X_ACC_DGRAD = True self.FP8_2X_ACC_DGRAD = True
cls.FP8_2X_ACC_WGRAD = True self.FP8_2X_ACC_WGRAD = True
@staticmethod def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
def finalize() -> None: """Gets the scaling mode for a specific tensor's usage type."""
"""Reset the delayed scaling configuration.""" return ScalingMode.DELAYED_TENSOR_SCALING
QuantizeConfig.finalize()
class CurrentScalingQuantizeConfig: class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for current scaling FP8 recipe. """Configuration class for current scaling FP8 recipe.
This class provides specific initialization and finalization for current scaling This class provides specific initialization and finalization for current scaling
FP8 quantization mode. FP8 quantization mode.
""" """
@staticmethod def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize current scaling FP8 configuration. """Initialize current scaling FP8 configuration.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The FP8 recipe to use for initialization
""" """
cls = QuantizeConfig super().initialize_from_recipe(fp8_recipe)
cls.initialize(fp8_recipe) self.AMAX_HISTORY_LEN = 0
cls.AMAX_HISTORY_LEN = 0
@staticmethod def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
def finalize() -> None: """Gets the scaling mode for a specific tensor's usage type."""
"""Reset the current scaling configuration.""" return ScalingMode.CURRENT_TENSOR_SCALING
QuantizeConfig.finalize()
class BlockScalingQuantizeConfig: class BlockScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for block scaling FP8 recipe. """Configuration class for block scaling FP8 recipe.
This class provides specific initialization and finalization for block scaling This class provides specific initialization and finalization for block scaling
FP8 quantization mode. FP8 quantization mode.
""" """
@staticmethod def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None:
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize block scaling FP8 configuration. """Initialize block scaling FP8 configuration.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The FP8 recipe to use for initialization
""" """
cls = QuantizeConfig super().initialize_from_recipe(fp8_recipe)
cls.initialize(fp8_recipe) self.AMAX_HISTORY_LEN = 0
cls.AMAX_HISTORY_LEN = 0
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
return ScalingMode.MXFP8_1D_SCALING
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
@staticmethod def get_quantize_config():
def finalize() -> None: """Global instance of BaseQuantizeConfig set by fp8_autocast context."""
"""Reset the block scaling configuration.""" return _QUANTIZE_CONFIG
QuantizeConfig.finalize()
def get_quantize_config_class(
fp8_recipe: recipe.Recipe,
) -> Type[BaseQuantizeConfig]:
"""Get the quantization configuration based on the FP8 recipe.
Args:
fp8_recipe: The FP8 recipe to use for initialization
Returns:
The quantization config class corresponding to the given recipe.
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return CurrentScalingQuantizeConfig
raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")
@contextmanager @contextmanager
...@@ -404,22 +439,22 @@ def fp8_autocast( ...@@ -404,22 +439,22 @@ def fp8_autocast(
if fp8_recipe is None: if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling() fp8_recipe = recipe.DelayedScaling()
Config = DelayedScalingQuantizeConfig global _QUANTIZE_CONFIG
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig old_quantize_config = _QUANTIZE_CONFIG
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
Config = CurrentScalingQuantizeConfig _QUANTIZE_CONFIG = NoOpQuantizeConfig()
try: try:
with global_shard_guard(mesh_resource): with global_shard_guard(mesh_resource):
if enabled: if enabled:
fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe)) _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)()
assert fp8_available, reason_for_no_fp8 is_supported, reason = _QUANTIZE_CONFIG.is_supported()
assert is_supported, reason
Config.initialize(fp8_recipe) _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe)
yield yield
finally: finally:
Config.finalize() _QUANTIZE_CONFIG = old_quantize_config
def get_delayed_scaling(): def get_delayed_scaling():
...@@ -437,12 +472,12 @@ def get_delayed_scaling(): ...@@ -437,12 +472,12 @@ def get_delayed_scaling():
an instance of DelayedScaling which is set via fp8_autocast. an instance of DelayedScaling which is set via fp8_autocast.
""" """
amax_compute_algo = ( amax_compute_algo = (
"max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
) )
return recipe.DelayedScaling( return recipe.DelayedScaling(
margin=int(QuantizeConfig.MARGIN), margin=int(get_quantize_config().MARGIN),
fp8_format=QuantizeConfig.FP8_FORMAT, fp8_format=get_quantize_config().FP8_FORMAT,
amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN, amax_history_len=get_quantize_config().AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo, amax_compute_algo=amax_compute_algo,
) )
...@@ -581,6 +616,3 @@ def apply_padding_to_scale_inv( ...@@ -581,6 +616,3 @@ def apply_padding_to_scale_inv(
# Pad the scales with the lowest representable value (2^-127) and return # Pad the scales with the lowest representable value (2^-127) and return
pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
...@@ -21,9 +21,10 @@ from transformer_engine.common import recipe ...@@ -21,9 +21,10 @@ from transformer_engine.common import recipe
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import ( from .helper import (
QuantizeConfig, get_quantize_config,
get_quantize_config_class,
AmaxComputeAlgo, AmaxComputeAlgo,
_get_scaling_mode, TensorSource,
) )
from .device_utils import is_fp8_gemm_with_all_layouts_supported from .device_utils import is_fp8_gemm_with_all_layouts_supported
...@@ -56,7 +57,7 @@ def compute_scale_from_amax( ...@@ -56,7 +57,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None: if scale is None:
scale = jnp.ones((1,)) scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale)
return sf return sf
...@@ -234,7 +235,7 @@ class CurrentScaleQuantizer(Quantizer): ...@@ -234,7 +235,7 @@ class CurrentScaleQuantizer(Quantizer):
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = jnp.max(jnp.abs(x)).reshape((1,)) amax = jnp.max(jnp.abs(x)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scaled_x = x.astype(compute_dtype) * scale scaled_x = x.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
...@@ -320,7 +321,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -320,7 +321,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field( amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32) default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
) )
def tree_flatten(self): def tree_flatten(self):
...@@ -397,7 +398,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -397,7 +398,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value Updated scale value
""" """
# 2. Calculate the current scale # 2. Calculate the current scale
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True) amax = jnp.max(amax_history, axis=-1, keepdims=True)
else: else:
amax = amax_history[0:1] amax = amax_history[0:1]
...@@ -827,12 +828,21 @@ class QuantizerFactory: ...@@ -827,12 +828,21 @@ class QuantizerFactory:
@staticmethod @staticmethod
def _create_set( def _create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs x_scaling_mode,
kernel_scaling_mode,
grad_scaling_mode,
fwd_dtype,
bwd_dtype,
is_2x2x,
n_groups,
**kwargs,
) -> QuantizerSet: ) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes. """Create a set of quantizers for forward and backward passes.
Args: Args:
scaling_mode: Scaling mode to use x_scaling_mode: Scaling mode to use for input tensor 'x'
kernel_scaling_mode: Scaling mode to use for kernel tensor
grad_scaling_mode: Scaling mode to use for gradient tensor
fwd_dtype: Data type for forward pass fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization is_2x2x: Whether to use 2x2x quantization
...@@ -846,9 +856,9 @@ class QuantizerFactory: ...@@ -846,9 +856,9 @@ class QuantizerFactory:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else: else:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if scaling_mode.is_1d_block_scaling(): if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE: if get_quantize_config().INFERENCE_MODE:
q_layout_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
...@@ -868,12 +878,12 @@ class QuantizerFactory: ...@@ -868,12 +878,12 @@ class QuantizerFactory:
else: else:
args_x = args_kernel = args_grad = {} args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_kernel = QuantizerFactory.create( q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel 1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
) )
q_dgrad = QuantizerFactory.create( q_dgrad = QuantizerFactory.create(
1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad 1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
) )
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
...@@ -892,10 +902,10 @@ class QuantizerFactory: ...@@ -892,10 +902,10 @@ class QuantizerFactory:
Args: Args:
n_quantizer_sets: Number of quantizer sets to create n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X
n_groups: n_groups:
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
...@@ -912,27 +922,44 @@ class QuantizerFactory: ...@@ -912,27 +922,44 @@ class QuantizerFactory:
) )
if fp8_recipe is not None: if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic quantize_config = get_quantize_config_class(fp8_recipe)()
scaling_mode = _get_scaling_mode(fp8_recipe) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
elif scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
else: else:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None: if is_2x2x is None:
if scaling_mode.is_1d_block_scaling(): # TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling():
is_2x2x = True is_2x2x = True
elif scaling_mode.is_tensor_scaling(): elif x_scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported() is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!" assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = [] q_set = []
for _ in range(n_quantizer_sets): for _ in range(n_quantizer_sets):
q_set.append( q_set.append(
QuantizerFactory._create_set( QuantizerFactory._create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs x_scaling_mode=x_scaling_mode,
kernel_scaling_mode=kernel_scaling_mode,
grad_scaling_mode=grad_scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x,
n_groups=n_groups,
**kwargs,
) )
) )
......
...@@ -396,7 +396,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -396,7 +396,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
The quantize layout for the tensor usage The quantize layout for the tensor usage
""" """
# If we need to support 1x1x for inference in the future # If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE: # if get_quantize_config().INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!") # assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS: # if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE # return QuantizeLayout.ROWWISE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment