Unverified commit a0754757, authored by jberchtold-nvidia, committed by GitHub

[JAX] Improve support and testing for direct recipe usage without autocast contexts (#2366)



* Refactor to avoid storing a global quantization config so direct recipe passing works as intended
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix use_split_accumulator for current scaling recipe
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix tests that pass direct recipe and were missing quantize meta set
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Revert "fix use_split_accumulator for current scaling recipe"

This reverts commit a74ab7df812ec0a069b1bdd208debb93ec25a900.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix ci failures
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix amax_history post_init
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update transformer_engine/jax/quantize/quantizer.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix ci failures
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix ci issue
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* address comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* make recipe assertion classes in test_recipe_characteristics not inherit from unittest.TestCase
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b88f727b
......@@ -40,6 +40,8 @@ from transformer_engine.jax.quantize import (
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
QuantizeMetaSet,
QuantizeMeta,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
......@@ -1457,7 +1459,12 @@ class TestDense:
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
......@@ -1516,7 +1523,12 @@ class TestFusedDense:
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
if norm_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
......@@ -1605,6 +1617,9 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2,
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
if norm_type == "layernorm":
......
......@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
get_quantize_config,
get_global_quantize_recipe,
get_quantize_config_with_recipe,
ScalingMode,
is_fp8_available,
update_collections,
......@@ -358,7 +359,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params)
if get_quantize_config().is_fp8_enabled():
if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled():
for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs,
......@@ -368,14 +369,24 @@ class BaseRunner:
test_layer,
)
if (
get_quantize_config().get_scaling_mode(TensorSource.X)
get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode(
TensorSource.X
)
== ScalingMode.DELAYED_TENSOR_SCALING
):
_, updated_quantize_meta = flax.core.pop(
updated_state[0], get_quantize_config().COLLECTION_NAME
updated_state[0],
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME,
)
test_others = update_collections(
{get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
{
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME: updated_quantize_meta
},
test_others,
)
del updated_quantize_meta
del updated_state
......
......@@ -4,6 +4,7 @@
import unittest
from functools import partial
from abc import ABC, abstractmethod
import flax
import jax
......@@ -13,6 +14,7 @@ from flax import linen as nn
from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.common.recipe import (
Recipe,
DelayedScaling,
MXFP8BlockScaling,
Float8CurrentScaling,
......@@ -21,13 +23,13 @@ from transformer_engine.common.recipe import (
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
get_quantize_config,
get_global_quantize_recipe,
get_quantize_config_with_recipe,
get_supported_quantization_recipes,
is_scaling_mode_supported,
ScalingMode,
update_collections,
TensorSource,
QuantizerFactory,
QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
......@@ -49,16 +51,17 @@ def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
# Define a function with a custom VJP (vector-Jacobian product)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantizer_check(inner_quantizer_set, assertion_func, x):
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0]
def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
assertion_func(inner_quantizer_set.x, TensorSource.X)
assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
return x
return x, (inner_quantizer_set,)
def quantizer_check_bwd(ctx, g):
return (g,)
def quantizer_check_bwd(assertion_func, ctx, g):
(inner_quantizer_set,) = ctx
return (inner_quantizer_set, g)
quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
return quantizer_check(outer_quantizer_set, assertion_func, x)
......@@ -69,10 +72,11 @@ class TestModule(TransformerEngineBase):
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func: callable
direct_recipe: Recipe
@nn.compact
def __call__(self, x):
quantizer_set = self.generate_quantizer_set()
quantizer_set = self.generate_quantizer_set(fp8_recipe=self.direct_recipe)
return quantizer_check_vjp(quantizer_set, self.assertion_func, x)
......@@ -97,167 +101,239 @@ class TestHelper(unittest.TestCase):
self.assertEqual(updated_state["test2"], original_val)
class TestFP8Functions(unittest.TestCase):
def assert_fp8_format(quantizer, tensor_source, fp8_format):
if fp8_format == FP8Format.HYBRID:
if tensor_source == TensorSource.DGRAD:
assert quantizer.q_dtype == jnp.float8_e5m2
else:
assert quantizer.q_dtype == jnp.float8_e4m3fn
elif fp8_format == FP8Format.E4M3:
assert quantizer.q_dtype == jnp.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 format: {fp8_format}")
def _check_default_state(self):
self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
class RecipeAssertionBase(ABC):
"""Base class for defining recipe assertions."""
@abstractmethod
def assert_context(self, ref_recipe, quantize_config):
"""Asserts that the quantize_config matches the expected properties from the reference recipe when the recipe is used with an autocast context.
Args:
ref_recipe: The reference quantization recipe.
quantize_config: The quantization configuration to be checked.
"""
pass
@abstractmethod
def assert_quantizers(self, ref_recipe, quantizer, tensor_source):
"""Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction.
Args:
ref_recipe: The reference quantization recipe.
quantizer: The quantizer to be checked.
tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD).
"""
pass
class DelayedScalingRecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.MARGIN == ref_recipe.margin
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
assert quantize_config.AMAX_HISTORY_LEN == ref_recipe.amax_history_len
assert quantize_config.AMAX_COMPUTE_ALGO.value == ref_recipe.amax_compute_algo
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source),
ScalingMode.CURRENT_TENSOR_SCALING,
assert (
quantize_config.get_scaling_mode(tensor_source)
== ScalingMode.DELAYED_TENSOR_SCALING
)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
def assert_quantizers(self, ref_recipe: DelayedScaling, quantizer, tensor_source):
assert quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
assert quantizer.margin == ref_recipe.margin
assert quantizer.amax_compute_algo.value == ref_recipe.amax_compute_algo
assert quantizer.amax_history.shape == (ref_recipe.amax_history_len,)
assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)
class CurrentScalingRecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
assert (
quantize_config.get_scaling_mode(tensor_source)
== ScalingMode.CURRENT_TENSOR_SCALING
)
def _compare_nvfp4_scaling(self, test):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
def assert_quantizers(self, ref_recipe: Float8CurrentScaling, quantizer, tensor_source):
assert quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)
class MXFP8RecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
for tensor_source in TensorSource:
assert quantize_config.get_scaling_mode(tensor_source) == ScalingMode.MXFP8_1D_SCALING
def assert_quantizers(self, ref_recipe: MXFP8BlockScaling, quantizer, tensor_source):
assert quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)
class NVFP4RecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[1]
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
if (not ref_recipe.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
self.assertEqual(
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
)
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
self.assertEqual(
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization
)
def _compare_nvfp4_scaling_quantizers(self, test):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
assert quantize_config.get_scaling_mode(tensor_source) == target_scaling_mode
assert quantize_config.DISABLE_STOCHASTIC_ROUNDING == ref_recipe.disable_stochastic_rounding
assert quantize_config.DISABLE_RHT == ref_recipe.disable_rht
assert quantize_config.DISABLE_2D_QUANTIZATION == ref_recipe.disable_2d_quantization
def assert_quantizers(self, ref_recipe: NVFP4BlockScaling, quantizer, tensor_source):
if tensor_source == TensorSource.KERNEL and not ref_recipe.disable_2d_quantization:
assert quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING
else:
assert quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
def assertion_func(quantizer, tensor_source):
if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
self.assertIsNone(quantizer.stochastic_rounding_rng_state)
if ref_recipe.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
assert quantizer.stochastic_rounding_rng_state is None
else:
self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)
assert quantizer.stochastic_rounding_rng_state is not None
expected_rht = (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
and not test.disable_rht
and not ref_recipe.disable_rht
)
self.assertEqual(quantizer.use_rht, expected_rht)
assert quantizer.use_rht == expected_rht
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assertion_func)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
self._check_default_state()
def _check_default_state(self):
self.assertEqual(get_global_quantize_recipe(), None)
with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
self._check_default_state()
def _test_recipe(self, quantization_recipe: Recipe, cls: RecipeAssertionBase):
"""Tests a quantization recipe by verifying its behavior in both autocast and direct application contexts."""
assert_context_func = cls().assert_context
assert_quantizer_func = partial(cls().assert_quantizers, quantization_recipe)
self._test_recipe_autocast(quantization_recipe, assert_context_func, assert_quantizer_func)
self._test_recipe_direct(quantization_recipe, assert_quantizer_func)
def _test_recipe_autocast(
self, quantization_recipe, assert_context_func, assert_quantizer_func
):
"""Tests a quantization recipe within an autocast context by verifying the quantize config and quantizers in a test module."""
self._check_default_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(ds)
with autocast(enabled=False, recipe=quantization_recipe, mesh_resource=MeshResource()):
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(ds)
with autocast(enabled=True, recipe=quantization_recipe, mesh_resource=MeshResource()):
quantize_config = self._get_global_quantize_config()
assert_context_func(quantization_recipe, quantize_config)
self._test_quantizer_in_model(assert_quantizer_func)
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_current_scaling(self):
def _test_recipe_direct(self, quantization_recipe, assert_quantizer_func):
"""Tests a quantization recipe by directly passing it to a test module and verifying the quantizers."""
self._check_default_state()
with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
self._test_quantizer_in_model(assert_quantizer_func, direct_recipe=quantization_recipe)
self._check_default_state()
self._check_default_state()
def _test_quantizer_in_model(self, assert_quantizer_func, direct_recipe=None):
"""Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary.
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
Args:
assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None.
direct_recipe: An optional quantization recipe to be passed directly to the test module. This is an alternative API to using autocast contexts.
"""
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assert_quantizer_func, direct_recipe=direct_recipe)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
self._check_default_state()
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_current_scaling(cs)
def _get_global_quantize_config(self):
quantization_recipe = get_global_quantize_recipe()
assert quantization_recipe is not None, "No global quantization recipe set"
quantize_config = get_quantize_config_with_recipe(quantization_recipe)
assert (
quantize_config.is_fp8_enabled()
), "Quantization not enabled in global quantize config"
return quantize_config
self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
self._test_recipe(
quantization_recipe=DelayedScaling(),
cls=DelayedScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=DelayedScaling(
margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1
),
cls=DelayedScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=DelayedScaling(
margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1
),
cls=DelayedScalingRecipeAssertion,
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_current_scaling(self):
self._test_recipe(
quantization_recipe=Float8CurrentScaling(),
cls=CurrentScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3),
cls=CurrentScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID),
cls=CurrentScalingRecipeAssertion,
)
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_autocast_mxfp8_block_scaling(self):
self._check_default_state()
with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
bs = MXFP8BlockScaling()
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
self._test_recipe(
quantization_recipe=MXFP8BlockScaling(),
cls=MXFP8RecipeAssertion,
)
@unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_autocast_nvfp4_block_scaling(self):
self._check_default_state()
with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
self._check_default_state()
self._check_default_state()
bs = NVFP4BlockScaling()
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
bs = NVFP4BlockScaling(
self._test_recipe(
quantization_recipe=NVFP4BlockScaling(),
cls=NVFP4RecipeAssertion,
)
self._test_recipe(
quantization_recipe=NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
),
cls=NVFP4RecipeAssertion,
)
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
self._check_default_state()
class TestJaxprAndHlo:
......
......@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
"""
Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the get_quantize_config().initialize() methods.
This helper is used in the get_quantize_config_with_recipe().initialize() methods.
Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
......@@ -38,12 +38,13 @@ from ..quantize import (
ScalingMode,
Quantizer,
GroupedQuantizer,
get_quantize_config,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......@@ -1246,7 +1247,7 @@ def _te_gemm(
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
use_split_accumulator: bool = None,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]:
......@@ -1258,6 +1259,13 @@ def _te_gemm(
DeprecationWarning,
)
if use_split_accumulator is None:
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
# Prepare non-quantized GEMM operands
lhs_data = lhs
rhs_data = rhs
......@@ -1720,10 +1728,15 @@ def _jax_gemm(
assert (
rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
precision = (
jax.lax.Precision.HIGHEST
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
......@@ -820,7 +820,7 @@ def _quantize_dbias_impl(
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling
......@@ -1227,7 +1227,7 @@ def grouped_quantize(
)
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in (
......
......@@ -27,7 +27,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported,
TensorUsage,
get_quantize_config,
)
......@@ -95,7 +94,7 @@ def dense(
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......
......@@ -33,10 +33,11 @@ from ..cpp_extensions import (
)
from ..quantize import (
QuantizerFactory,
get_quantize_config,
get_global_quantize_recipe,
QuantizeMetaSet,
TensorSource,
get_quantize_config_with_recipe,
noop_quantizer_set,
)
PRNGKey = Any
......@@ -355,17 +356,17 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM.
"""
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
collection_name = (
variable_collection
if variable_collection is not None
else get_quantize_config().COLLECTION_NAME
else quantize_config.COLLECTION_NAME
)
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x"
)
......@@ -492,7 +493,11 @@ class DenseGeneral(TransformerEngineBase):
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype)
if self.use_bias:
......@@ -505,9 +510,6 @@ class DenseGeneral(TransformerEngineBase):
else:
bias = None
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs,
......@@ -712,7 +714,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
)
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
quantizer_set != noop_quantizer_set
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -763,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape,
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......@@ -1042,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm = (
get_quantize_config().is_fp8_enabled()
ffn1_quantizer_set != noop_quantizer_set
and not self.return_layernorm_output
and self.enable_layernorm
)
......@@ -1128,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if ffn1_quantizer_set == noop_quantizer_set:
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
......@@ -1140,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape,
self.dtype,
)
if not get_quantize_config().is_fp8_enabled():
if ffn2_quantizer_set == noop_quantizer_set:
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......
......@@ -23,7 +23,6 @@ from .quantize import (
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
get_quantize_config,
)
......@@ -73,7 +72,7 @@ def layernorm_dense(
- Quantization is applied to both the normalized input and kernel
"""
if not get_quantize_config().is_fp8_enabled():
if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......
......@@ -28,7 +28,6 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
TensorUsage,
get_quantize_config,
)
......@@ -114,7 +113,7 @@ def layernorm_mlp(
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled():
if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype)
......
......@@ -46,7 +46,7 @@ from .scaling_modes import ScalingMode
from .device_utils import get_device_compute_capability
__all__ = [
"get_quantize_config",
"get_global_quantize_recipe",
"get_quantize_config_with_recipe",
"autocast",
"fp8_autocast",
......@@ -475,7 +475,12 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
(self.AMAX_HISTORY_LEN,),
jnp.float32,
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
return QuantizeMeta(
margin=self.MARGIN,
amax_compute_algo=self.AMAX_COMPUTE_ALGO,
scale=scale,
amax_history=amax_history,
)
class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
......@@ -669,14 +674,6 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
)
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by autocast context."""
return _QUANTIZE_CONFIG
def get_quantize_config_class(
fp8_recipe: Recipe,
) -> Type[BaseQuantizeConfig]:
......@@ -687,6 +684,8 @@ def get_quantize_config_class(
Returns:
The quantization config class corresponding to the given recipe.
"""
if fp8_recipe is None:
return NoOpQuantizeConfig
if isinstance(fp8_recipe, DelayedScaling):
return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, MXFP8BlockScaling):
......@@ -701,10 +700,23 @@ def get_quantize_config_class(
def get_quantize_config_with_recipe(fp8_recipe: Recipe):
"""Get the quantization configuration object based on the FP8 recipe."""
config = get_quantize_config_class(fp8_recipe)()
if fp8_recipe is not None:
config.initialize_from_recipe(fp8_recipe)
return config
_GLOBAL_RECIPE: Optional[Recipe] = None
def get_global_quantize_recipe() -> Optional[Recipe]:
"""Get the global quantization recipe if set.
Returns:
The global quantization recipe or None if not set.
"""
return _GLOBAL_RECIPE
@contextmanager
def autocast(
enabled: bool = False,
......@@ -751,22 +763,21 @@ def autocast(
if recipe is None:
recipe = DelayedScaling()
global _QUANTIZE_CONFIG
global _GLOBAL_RECIPE
old_quantize_config = _QUANTIZE_CONFIG
old_global_recipe = _GLOBAL_RECIPE
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
_GLOBAL_RECIPE = None
try:
with global_shard_guard(mesh_resource):
if enabled:
_QUANTIZE_CONFIG = get_quantize_config_class(recipe)()
is_supported, reason = _QUANTIZE_CONFIG.is_supported()
_GLOBAL_RECIPE = recipe
is_supported, reason = get_quantize_config_class(_GLOBAL_RECIPE)().is_supported()
assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(recipe)
yield
finally:
_QUANTIZE_CONFIG = old_quantize_config
_GLOBAL_RECIPE = old_global_recipe
@contextmanager
......
......@@ -28,7 +28,7 @@ from .tensor import (
NoScaleTensor,
)
from .helper import (
get_quantize_config,
get_global_quantize_recipe,
get_quantize_config_with_recipe,
AmaxComputeAlgo,
TensorSource,
......@@ -50,7 +50,7 @@ __all__ = [
def compute_scale_from_amax(
amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
amax: jnp.ndarray, q_dtype: jnp.dtype, margin: float, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
"""Compute scale from amax value.
......@@ -64,7 +64,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None:
scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = (fp8_max / amax) / (2**margin)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
......@@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
"""
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
......@@ -254,8 +255,7 @@ class CurrentScaleQuantizer(Quantizer):
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scale = compute_scale_from_amax(amax, self.q_dtype, margin=0.0)
scaled_x = x.data.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
......@@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
margin: Margin value for scale computation
amax_compute_algo: Algorithm for computing amax
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
margin: float = 0.0
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
)
amax_history: jnp.ndarray = field(default_factory=lambda: jnp.zeros((1024,), jnp.float32))
def __post_init__(self):
    """Validate that every required delayed-scaling field was explicitly provided."""
    # Checked in declaration order so the first missing field is reported.
    for field_name in ("margin", "amax_compute_algo", "amax_history"):
        assert getattr(self, field_name) is not None, f"{field_name} must be specified"
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -352,6 +358,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
self.q_layout,
self.data_layout,
self.checkpoint_name,
self.margin,
self.amax_compute_algo,
)
return (children, aux_data)
......@@ -407,12 +415,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Returns:
Updated AMAX history
"""
amax_history = amax_history.at[0].set(new_amax[0])
amax_history = amax_history.at[0].set(new_amax.reshape((1,))[0])
return amax_history
@staticmethod
@partial(jax.jit, static_argnums=(2, 3, 4))
def _compute_scale(
    amax_history, scale, q_dtype, amax_compute_algo: AmaxComputeAlgo, margin: float
):
    """Compute an updated scale from the AMAX history.

    Args:
        amax_history: Buffer of recorded AMAX values.
        scale: Previous scale, used by compute_scale_from_amax as the
            fallback when AMAX is non-positive or non-finite.
        q_dtype: Target quantization dtype.
        amax_compute_algo: How to reduce the history — MAX takes the
            maximum over the whole buffer, otherwise the most recent entry
            is used.
        margin: Margin exponent applied to the computed scale.

    Returns:
        Updated scale value.
    """
    # Reduce the history to a single (1,)-shaped AMAX per the algorithm.
    if amax_compute_algo is AmaxComputeAlgo.MAX:
        current_amax = jnp.max(amax_history, axis=-1, keepdims=True)
    else:
        current_amax = amax_history[0:1]

    return compute_scale_from_amax(current_amax, q_dtype, margin=margin, scale=scale)
@staticmethod
@jax.jit
......@@ -453,7 +463,9 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
new_amax: New maximum absolute value to add to history
"""
amax_history = self._update_amax_history(self.amax_history, new_amax)
self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
self.scale = self._compute_scale(
amax_history, self.scale, self.q_dtype, self.amax_compute_algo, self.margin
)
self.amax_history = self._roll_and_reset_amax_history(amax_history)
......@@ -1124,6 +1136,7 @@ class QuantizerFactory:
bwd_dtype,
is_2x2x,
n_groups,
is_inference_mode=False,
checkpoint_name: Optional[str] = None,
**kwargs,
) -> QuantizerSet:
......@@ -1137,6 +1150,7 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
n_groups
is_inference_mode: Whether to create quantizers for inference mode. This option is not fully supported yet
checkpoint_name: Optional name for checkpointing quantizations
**kwargs: Additional arguments for quantizer initialization
......@@ -1149,7 +1163,7 @@ class QuantizerFactory:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE
if get_quantize_config().INFERENCE_MODE:
if is_inference_mode:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
......@@ -1206,10 +1220,10 @@ class QuantizerFactory:
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X
scaling_mode: Scaling mode to use, default is get the scaling mode from the specified or global recipe
fwd_dtype: Data type for forward pass, default is get the fwd dtype from the specified or global recipe
bwd_dtype: Data type for backward pass, default is get the bwd dtype from the specified or global recipe
is_2x2x: Whether to use 2x2x quantization, default is determined based on the specified or global recipe
n_groups:
checkpoint_name: Optional name for checkpointing quantizations
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
......@@ -1226,25 +1240,46 @@ class QuantizerFactory:
" scaling mode differs between x, kernel, and grad in the quantizer set."
)
# TODO(jberchtold): Currently this is a limitation because we only support automatically populating quantizer fields based on a given recipe when using Flax. In the generic quantizer logic, we cannot assume Flax is being used, so we require the user to provide the quantize_meta_set created by quantize_config.get_quantize_flax_meta() or the same data created by themselves if they are passing a recipe here directly.
assert (
fp8_recipe is None or "quantize_meta_set" in kwargs
), "When fp8_recipe is specified, quantize_meta_set must be provided in kwargs."
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
if fp8_recipe is not None:
assert scaling_mode is None, (
"scaling_mode should not be specified when fp8_recipe is provided either directly"
" or through an autocast context."
)
assert fwd_dtype is None, (
"fwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
assert bwd_dtype is None, (
"bwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = quantize_config.FWD_DTYPE
bwd_dtype = quantize_config.BWD_DTYPE
is_inference_mode = quantize_config.INFERENCE_MODE
else:
if scaling_mode is not None:
x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode
else:
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)
# TODO(jberchtold): make a way to explicitly pass a no scaling recipe here if we need other quantization config attributes in the future since NoOpQuantizeConfig already exists, we just can't use it here with direct recipe passing because we cannot differentiate between fp8_recipe=None meaning no recipe specified vs explicitly no quantization desired.
x_scaling_mode = ScalingMode.NO_SCALING
kernel_scaling_mode = ScalingMode.NO_SCALING
grad_scaling_mode = ScalingMode.NO_SCALING
is_inference_mode = False
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None:
# TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling():
......@@ -1253,7 +1288,6 @@ class QuantizerFactory:
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = []
......@@ -1267,6 +1301,7 @@ class QuantizerFactory:
bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x,
n_groups=n_groups,
is_inference_mode=is_inference_mode,
checkpoint_name=checkpoint_name,
**kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment