Unverified commit a0754757, authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Improve support and testing for direct recipe usage without autocast contexts (#2366)



* Refactor to avoid storing a global quantization config so direct recipe passing works as intended
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix use_split_accumulator for current scaling recipe
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix tests that pass direct recipe and were missing quantize meta set
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Revert "fix use_split_accumulator for current scaling recipe"

This reverts commit a74ab7df812ec0a069b1bdd208debb93ec25a900.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix ci failures
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix amax_history post_init
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update transformer_engine/jax/quantize/quantizer.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix ci failures
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix ci issue
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* address comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* make recipe assertion classes in test_recipe_characteristics not inherit from unittest.TestCase
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b88f727b
...@@ -40,6 +40,8 @@ from transformer_engine.jax.quantize import ( ...@@ -40,6 +40,8 @@ from transformer_engine.jax.quantize import (
QuantizerFactory, QuantizerFactory,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
QuantizeMetaSet,
QuantizeMeta,
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
...@@ -1457,7 +1459,12 @@ class TestDense: ...@@ -1457,7 +1459,12 @@ class TestDense:
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
n_iterations = 3 if recipe.delayed() else 1 n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm): with use_jax_gemm(enabled=with_jax_gemm):
...@@ -1516,7 +1523,12 @@ class TestFusedDense: ...@@ -1516,7 +1523,12 @@ class TestFusedDense:
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) quantizer_set = QuantizerFactory.create_set(
fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
)
if norm_type == "layernorm": if norm_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
...@@ -1605,6 +1617,9 @@ class TestFusedDense: ...@@ -1605,6 +1617,9 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set( quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, n_quantizer_sets=2,
fp8_recipe=recipe, fp8_recipe=recipe,
quantize_meta_set=QuantizeMetaSet(
x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
),
) )
if norm_type == "layernorm": if norm_type == "layernorm":
......
...@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer ...@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
get_quantize_config, get_global_quantize_recipe,
get_quantize_config_with_recipe,
ScalingMode, ScalingMode,
is_fp8_available, is_fp8_available,
update_collections, update_collections,
...@@ -358,7 +359,7 @@ class BaseRunner: ...@@ -358,7 +359,7 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params) ref_params, test_params = self._sync_params(ref_params, test_params)
if get_quantize_config().is_fp8_enabled(): if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled():
for _ in range(4): for _ in range(4):
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs, inputs,
...@@ -368,14 +369,24 @@ class BaseRunner: ...@@ -368,14 +369,24 @@ class BaseRunner:
test_layer, test_layer,
) )
if ( if (
get_quantize_config().get_scaling_mode(TensorSource.X) get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode(
TensorSource.X
)
== ScalingMode.DELAYED_TENSOR_SCALING == ScalingMode.DELAYED_TENSOR_SCALING
): ):
_, updated_quantize_meta = flax.core.pop( _, updated_quantize_meta = flax.core.pop(
updated_state[0], get_quantize_config().COLLECTION_NAME updated_state[0],
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME,
) )
test_others = update_collections( test_others = update_collections(
{get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others {
get_quantize_config_with_recipe(
get_global_quantize_recipe()
).COLLECTION_NAME: updated_quantize_meta
},
test_others,
) )
del updated_quantize_meta del updated_quantize_meta
del updated_state del updated_state
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import unittest import unittest
from functools import partial from functools import partial
from abc import ABC, abstractmethod
import flax import flax
import jax import jax
...@@ -13,6 +14,7 @@ from flax import linen as nn ...@@ -13,6 +14,7 @@ from flax import linen as nn
from utils import assert_allclose, pytest_parametrize_wrapper from utils import assert_allclose, pytest_parametrize_wrapper
from transformer_engine.common.recipe import ( from transformer_engine.common.recipe import (
Recipe,
DelayedScaling, DelayedScaling,
MXFP8BlockScaling, MXFP8BlockScaling,
Float8CurrentScaling, Float8CurrentScaling,
...@@ -21,13 +23,13 @@ from transformer_engine.common.recipe import ( ...@@ -21,13 +23,13 @@ from transformer_engine.common.recipe import (
from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
get_quantize_config, get_global_quantize_recipe,
get_quantize_config_with_recipe,
get_supported_quantization_recipes, get_supported_quantization_recipes,
is_scaling_mode_supported, is_scaling_mode_supported,
ScalingMode, ScalingMode,
update_collections, update_collections,
TensorSource, TensorSource,
QuantizerFactory,
QuantizeLayout, QuantizeLayout,
) )
from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.quantize.helper import _format2dtypes
...@@ -49,16 +51,17 @@ def quantizer_check_vjp(outer_quantizer_set, assertion_func, x): ...@@ -49,16 +51,17 @@ def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
# Define a function with a custom VJP (vector-Jacobian product) # Define a function with a custom VJP (vector-Jacobian product)
@partial(jax.custom_vjp, nondiff_argnums=(1,)) @partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantizer_check(inner_quantizer_set, assertion_func, x): def quantizer_check(inner_quantizer_set, assertion_func, x):
return quantizer_check_fwd(inner_quantizer_set, assertion_func, x) return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0]
def quantizer_check_fwd(inner_quantizer_set, assertion_func, x): def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
assertion_func(inner_quantizer_set.x, TensorSource.X) assertion_func(inner_quantizer_set.x, TensorSource.X)
assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL) assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD) assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
return x return x, (inner_quantizer_set,)
def quantizer_check_bwd(ctx, g): def quantizer_check_bwd(assertion_func, ctx, g):
return (g,) (inner_quantizer_set,) = ctx
return (inner_quantizer_set, g)
quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd) quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
return quantizer_check(outer_quantizer_set, assertion_func, x) return quantizer_check(outer_quantizer_set, assertion_func, x)
...@@ -69,10 +72,11 @@ class TestModule(TransformerEngineBase): ...@@ -69,10 +72,11 @@ class TestModule(TransformerEngineBase):
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func: callable assertion_func: callable
direct_recipe: Recipe
@nn.compact @nn.compact
def __call__(self, x): def __call__(self, x):
quantizer_set = self.generate_quantizer_set() quantizer_set = self.generate_quantizer_set(fp8_recipe=self.direct_recipe)
return quantizer_check_vjp(quantizer_set, self.assertion_func, x) return quantizer_check_vjp(quantizer_set, self.assertion_func, x)
...@@ -97,167 +101,239 @@ class TestHelper(unittest.TestCase): ...@@ -97,167 +101,239 @@ class TestHelper(unittest.TestCase):
self.assertEqual(updated_state["test2"], original_val) self.assertEqual(updated_state["test2"], original_val)
class TestFP8Functions(unittest.TestCase): def assert_fp8_format(quantizer, tensor_source, fp8_format):
if fp8_format == FP8Format.HYBRID:
if tensor_source == TensorSource.DGRAD:
assert quantizer.q_dtype == jnp.float8_e5m2
else:
assert quantizer.q_dtype == jnp.float8_e4m3fn
elif fp8_format == FP8Format.E4M3:
assert quantizer.q_dtype == jnp.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 format: {fp8_format}")
def _check_default_state(self):
self.assertFalse(get_quantize_config().is_fp8_enabled()) class RecipeAssertionBase(ABC):
"""Base class for defining recipe assertions."""
def _compare_delay_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin) @abstractmethod
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) def assert_context(self, ref_recipe, quantize_config):
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) """Asserts that the quantize_config matches the expected properties from the reference recipe when the recipe is used with an autocast context.
self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo) Args:
ref_recipe: The reference quantization recipe.
def _compare_current_scaling(self, test): quantize_config: The quantization configuration to be checked.
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) """
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) pass
@abstractmethod
def assert_quantizers(self, ref_recipe, quantizer, tensor_source):
"""Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction.
Args:
ref_recipe: The reference quantization recipe.
quantizer: The quantizer to be checked.
tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD).
"""
pass
class DelayedScalingRecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.MARGIN == ref_recipe.margin
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
assert quantize_config.AMAX_HISTORY_LEN == ref_recipe.amax_history_len
assert quantize_config.AMAX_COMPUTE_ALGO.value == ref_recipe.amax_compute_algo
for tensor_source in TensorSource: for tensor_source in TensorSource:
self.assertEqual( assert (
get_quantize_config().get_scaling_mode(tensor_source), quantize_config.get_scaling_mode(tensor_source)
ScalingMode.CURRENT_TENSOR_SCALING, == ScalingMode.DELAYED_TENSOR_SCALING
) )
def _compare_mxfp8_scaling(self, test): def assert_quantizers(self, ref_recipe: DelayedScaling, quantizer, tensor_source):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) assert quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) assert quantizer.margin == ref_recipe.margin
assert quantizer.amax_compute_algo.value == ref_recipe.amax_compute_algo
assert quantizer.amax_history.shape == (ref_recipe.amax_history_len,)
assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)
class CurrentScalingRecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
for tensor_source in TensorSource: for tensor_source in TensorSource:
self.assertEqual( assert (
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING quantize_config.get_scaling_mode(tensor_source)
== ScalingMode.CURRENT_TENSOR_SCALING
) )
def _compare_nvfp4_scaling(self, test): def assert_quantizers(self, ref_recipe: Float8CurrentScaling, quantizer, tensor_source):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0]) assert quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1]) assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)
class MXFP8RecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
for tensor_source in TensorSource:
assert quantize_config.get_scaling_mode(tensor_source) == ScalingMode.MXFP8_1D_SCALING
def assert_quantizers(self, ref_recipe: MXFP8BlockScaling, quantizer, tensor_source):
assert quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)
class NVFP4RecipeAssertion(RecipeAssertionBase):
def assert_context(self, ref_recipe, quantize_config):
assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[0]
assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[1]
for tensor_source in TensorSource: for tensor_source in TensorSource:
target_scaling_mode = ( target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING ScalingMode.NVFP4_2D_SCALING
if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL if (not ref_recipe.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING else ScalingMode.NVFP4_1D_SCALING
) )
self.assertEqual( assert quantize_config.get_scaling_mode(tensor_source) == target_scaling_mode
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode assert quantize_config.DISABLE_STOCHASTIC_ROUNDING == ref_recipe.disable_stochastic_rounding
) assert quantize_config.DISABLE_RHT == ref_recipe.disable_rht
self.assertEqual( assert quantize_config.DISABLE_2D_QUANTIZATION == ref_recipe.disable_2d_quantization
get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding
) def assert_quantizers(self, ref_recipe: NVFP4BlockScaling, quantizer, tensor_source):
self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht) if tensor_source == TensorSource.KERNEL and not ref_recipe.disable_2d_quantization:
self.assertEqual( assert quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING
get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization else:
) assert quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
def _compare_nvfp4_scaling_quantizers(self, test):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
def assertion_func(quantizer, tensor_source): if ref_recipe.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD: assert quantizer.stochastic_rounding_rng_state is None
self.assertIsNone(quantizer.stochastic_rounding_rng_state)
else: else:
self.assertIsNotNone(quantizer.stochastic_rounding_rng_state) assert quantizer.stochastic_rounding_rng_state is not None
expected_rht = ( expected_rht = (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE} and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
and not test.disable_rht and not ref_recipe.disable_rht
) )
self.assertEqual(quantizer.use_rht, expected_rht) assert quantizer.use_rht == expected_rht
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assertion_func)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs) class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason) def _check_default_state(self):
def test_autocast_delayed_scaling(self): self.assertEqual(get_global_quantize_recipe(), None)
self._check_default_state()
with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()): def _test_recipe(self, quantization_recipe: Recipe, cls: RecipeAssertionBase):
self._check_default_state() """Tests a quantization recipe by verifying its behavior in both autocast and direct application contexts."""
assert_context_func = cls().assert_context
assert_quantizer_func = partial(cls().assert_quantizers, quantization_recipe)
self._test_recipe_autocast(quantization_recipe, assert_context_func, assert_quantizer_func)
self._test_recipe_direct(quantization_recipe, assert_quantizer_func)
def _test_recipe_autocast(
self, quantization_recipe, assert_context_func, assert_quantizer_func
):
"""Tests a quantization recipe within an autocast context by verifying the quantize config and quantizers in a test module."""
self._check_default_state() self._check_default_state()
with autocast(enabled=False, recipe=quantization_recipe, mesh_resource=MeshResource()):
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(ds)
self._check_default_state() self._check_default_state()
with autocast(enabled=True, recipe=quantization_recipe, mesh_resource=MeshResource()):
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) quantize_config = self._get_global_quantize_config()
with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()): assert_context_func(quantization_recipe, quantize_config)
self.assertTrue(get_quantize_config().is_fp8_enabled()) self._test_quantizer_in_model(assert_quantizer_func)
self._compare_delay_scaling(ds)
self._check_default_state() self._check_default_state()
@unittest.skipIf(not is_fp8_supported, reason=reason) def _test_recipe_direct(self, quantization_recipe, assert_quantizer_func):
def test_autocast_current_scaling(self): """Tests a quantization recipe by directly passing it to a test module and verifying the quantizers."""
self._check_default_state() self._check_default_state()
self._test_quantizer_in_model(assert_quantizer_func, direct_recipe=quantization_recipe)
with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
self._check_default_state() self._check_default_state()
self._check_default_state() def _test_quantizer_in_model(self, assert_quantizer_func, direct_recipe=None):
"""Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary.
cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) Args:
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()): assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None.
self.assertTrue(get_quantize_config().is_fp8_enabled()) direct_recipe: An optional quantization recipe to be passed directly to the test module. This is an alternative API to using autocast contexts.
self._compare_current_scaling(cs) """
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assert_quantizer_func, direct_recipe=direct_recipe)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
self._check_default_state() jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) def _get_global_quantize_config(self):
with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()): quantization_recipe = get_global_quantize_recipe()
self.assertTrue(get_quantize_config().is_fp8_enabled()) assert quantization_recipe is not None, "No global quantization recipe set"
self._compare_current_scaling(cs) quantize_config = get_quantize_config_with_recipe(quantization_recipe)
assert (
quantize_config.is_fp8_enabled()
), "Quantization not enabled in global quantize config"
return quantize_config
self._check_default_state() @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
self._test_recipe(
quantization_recipe=DelayedScaling(),
cls=DelayedScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=DelayedScaling(
margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1
),
cls=DelayedScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=DelayedScaling(
margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1
),
cls=DelayedScalingRecipeAssertion,
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_current_scaling(self):
self._test_recipe(
quantization_recipe=Float8CurrentScaling(),
cls=CurrentScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3),
cls=CurrentScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID),
cls=CurrentScalingRecipeAssertion,
)
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_autocast_mxfp8_block_scaling(self): def test_autocast_mxfp8_block_scaling(self):
self._check_default_state() self._test_recipe(
quantization_recipe=MXFP8BlockScaling(),
with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()): cls=MXFP8RecipeAssertion,
self._check_default_state() )
self._check_default_state()
bs = MXFP8BlockScaling()
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
@unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_autocast_nvfp4_block_scaling(self): def test_autocast_nvfp4_block_scaling(self):
self._check_default_state() self._test_recipe(
quantization_recipe=NVFP4BlockScaling(),
with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()): cls=NVFP4RecipeAssertion,
self._check_default_state() )
self._test_recipe(
self._check_default_state() quantization_recipe=NVFP4BlockScaling(
bs = NVFP4BlockScaling()
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
bs = NVFP4BlockScaling(
disable_stochastic_rounding=True, disable_stochastic_rounding=True,
disable_rht=True, disable_rht=True,
disable_2d_quantization=True, disable_2d_quantization=True,
),
cls=NVFP4RecipeAssertion,
) )
with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_nvfp4_scaling(bs)
self._compare_nvfp4_scaling_quantizers(bs)
self._check_default_state()
class TestJaxprAndHlo: class TestJaxprAndHlo:
......
...@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F ...@@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
""" """
Helper function to manage primitive states by name without modifying environment variables. Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the get_quantize_config().initialize() methods. This helper is used in the get_quantize_config_with_recipe().initialize() methods.
Args: Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......
...@@ -38,12 +38,13 @@ from ..quantize import ( ...@@ -38,12 +38,13 @@ from ..quantize import (
ScalingMode, ScalingMode,
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
get_quantize_config,
QuantizerSet, QuantizerSet,
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
get_quantize_config_with_recipe,
get_global_quantize_recipe,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
...@@ -1246,7 +1247,7 @@ def _te_gemm( ...@@ -1246,7 +1247,7 @@ def _te_gemm(
fuse_bias: bool = False, fuse_bias: bool = False,
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, use_split_accumulator: bool = None,
transpose_batch_sequence: bool = False, transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE, collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
...@@ -1258,6 +1259,13 @@ def _te_gemm( ...@@ -1258,6 +1259,13 @@ def _te_gemm(
DeprecationWarning, DeprecationWarning,
) )
if use_split_accumulator is None:
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
lhs_data = lhs lhs_data = lhs
rhs_data = rhs rhs_data = rhs
...@@ -1720,10 +1728,15 @@ def _jax_gemm( ...@@ -1720,10 +1728,15 @@ def _jax_gemm(
assert ( assert (
rhs.scaling_mode == lhs.scaling_mode rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
# TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also
# use context of the GEMM type so we can decide between fprop, dgrad, and wgrad
use_split_accumulator = get_quantize_config_with_recipe(
get_global_quantize_recipe()
).FP8_2X_ACC_FPROP
precision = ( precision = (
jax.lax.Precision.HIGHEST jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT
if get_quantize_config().FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
) )
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
......
...@@ -820,7 +820,7 @@ def _quantize_dbias_impl( ...@@ -820,7 +820,7 @@ def _quantize_dbias_impl(
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
) )
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling # Make sure to reset amax to zeros for DelayedScaling
...@@ -1227,7 +1227,7 @@ def grouped_quantize( ...@@ -1227,7 +1227,7 @@ def grouped_quantize(
) )
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups): for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
scale = scale.at[i].set(tmp_scale[0]) scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in ( is_tensor_scaling = quantizer.scaling_mode in (
......
...@@ -27,7 +27,6 @@ from .quantize import ( ...@@ -27,7 +27,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -95,7 +94,7 @@ def dense( ...@@ -95,7 +94,7 @@ def dense(
if transpose_batch_sequence: if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!") warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
......
...@@ -33,10 +33,11 @@ from ..cpp_extensions import ( ...@@ -33,10 +33,11 @@ from ..cpp_extensions import (
) )
from ..quantize import ( from ..quantize import (
QuantizerFactory, QuantizerFactory,
get_quantize_config, get_global_quantize_recipe,
QuantizeMetaSet, QuantizeMetaSet,
TensorSource, TensorSource,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
noop_quantizer_set,
) )
PRNGKey = Any PRNGKey = Any
...@@ -355,17 +356,17 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -355,17 +356,17 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Generate a set of FP8 meta for a GEMM. Generate a set of FP8 meta for a GEMM.
""" """
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
collection_name = ( collection_name = (
variable_collection variable_collection
if variable_collection is not None if variable_collection is not None
else get_quantize_config().COLLECTION_NAME else quantize_config.COLLECTION_NAME
) )
if fp8_recipe is None:
quantize_config = get_quantize_config()
else:
quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_meta = quantize_config.get_quantize_flax_meta( x_meta = quantize_config.get_quantize_flax_meta(
self, collection_name, postfix, TensorSource.X, "x" self, collection_name, postfix, TensorSource.X, "x"
) )
...@@ -492,7 +493,11 @@ class DenseGeneral(TransformerEngineBase): ...@@ -492,7 +493,11 @@ class DenseGeneral(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
if self.use_bias: if self.use_bias:
...@@ -505,9 +510,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -505,9 +510,6 @@ class DenseGeneral(TransformerEngineBase):
else: else:
bias = None bias = None
quantizer_set = self.generate_quantizer_set(
quantization_checkpoint_name=self.quantization_checkpoint_name
)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = dense( y = dense(
inputs, inputs,
...@@ -712,7 +714,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -712,7 +714,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
) )
fuse_layernorm = ( fuse_layernorm = (
get_quantize_config().is_fp8_enabled() quantizer_set != noop_quantizer_set
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -763,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -763,7 +765,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape, kernel_shape,
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -1042,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1042,7 +1044,7 @@ class LayerNormMLP(TransformerEngineBase):
# TODO(Phuong): use fuse_layernorm for high-precision # TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented # when NoOpQuantizer and Tensor are implemented
fuse_layernorm = ( fuse_layernorm = (
get_quantize_config().is_fp8_enabled() ffn1_quantizer_set != noop_quantizer_set
and not self.return_layernorm_output and not self.return_layernorm_output
and self.enable_layernorm and self.enable_layernorm
) )
...@@ -1128,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1128,7 +1130,7 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if ffn1_quantizer_set == noop_quantizer_set:
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
...@@ -1140,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1140,7 +1142,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel_2_shape, kernel_2_shape,
self.dtype, self.dtype,
) )
if not get_quantize_config().is_fp8_enabled(): if ffn2_quantizer_set == noop_quantizer_set:
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
......
...@@ -23,7 +23,6 @@ from .quantize import ( ...@@ -23,7 +23,6 @@ from .quantize import (
noop_quantizer_set, noop_quantizer_set,
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -73,7 +72,7 @@ def layernorm_dense( ...@@ -73,7 +72,7 @@ def layernorm_dense(
- Quantization is applied to both the normalized input and kernel - Quantization is applied to both the normalized input and kernel
""" """
if not get_quantize_config().is_fp8_enabled(): if quantizer_set == noop_quantizer_set:
input_dtype = x.dtype input_dtype = x.dtype
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
......
...@@ -28,7 +28,6 @@ from .quantize import ( ...@@ -28,7 +28,6 @@ from .quantize import (
QuantizerSet, QuantizerSet,
noop_quantizer_set, noop_quantizer_set,
TensorUsage, TensorUsage,
get_quantize_config,
) )
...@@ -114,7 +113,7 @@ def layernorm_mlp( ...@@ -114,7 +113,7 @@ def layernorm_mlp(
not zero_centered_gamma not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
if not get_quantize_config().is_fp8_enabled(): if quantizer_sets == (noop_quantizer_set, noop_quantizer_set):
input_dtype = x.dtype input_dtype = x.dtype
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
......
...@@ -46,7 +46,7 @@ from .scaling_modes import ScalingMode ...@@ -46,7 +46,7 @@ from .scaling_modes import ScalingMode
from .device_utils import get_device_compute_capability from .device_utils import get_device_compute_capability
__all__ = [ __all__ = [
"get_quantize_config", "get_global_quantize_recipe",
"get_quantize_config_with_recipe", "get_quantize_config_with_recipe",
"autocast", "autocast",
"fp8_autocast", "fp8_autocast",
...@@ -475,7 +475,12 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -475,7 +475,12 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
(self.AMAX_HISTORY_LEN,), (self.AMAX_HISTORY_LEN,),
jnp.float32, jnp.float32,
).value ).value
return QuantizeMeta(scale=scale, amax_history=amax_history) return QuantizeMeta(
margin=self.MARGIN,
amax_compute_algo=self.AMAX_COMPUTE_ALGO,
scale=scale,
amax_history=amax_history,
)
class CurrentScalingQuantizeConfig(BaseQuantizeConfig): class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
...@@ -669,14 +674,6 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -669,14 +674,6 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
) )
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
def get_quantize_config():
"""Global instance of BaseQuantizeConfig set by autocast context."""
return _QUANTIZE_CONFIG
def get_quantize_config_class( def get_quantize_config_class(
fp8_recipe: Recipe, fp8_recipe: Recipe,
) -> Type[BaseQuantizeConfig]: ) -> Type[BaseQuantizeConfig]:
...@@ -687,6 +684,8 @@ def get_quantize_config_class( ...@@ -687,6 +684,8 @@ def get_quantize_config_class(
Returns: Returns:
The quantization config class corresponding to the given recipe. The quantization config class corresponding to the given recipe.
""" """
if fp8_recipe is None:
return NoOpQuantizeConfig
if isinstance(fp8_recipe, DelayedScaling): if isinstance(fp8_recipe, DelayedScaling):
return DelayedScalingQuantizeConfig return DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, MXFP8BlockScaling): if isinstance(fp8_recipe, MXFP8BlockScaling):
...@@ -701,10 +700,23 @@ def get_quantize_config_class( ...@@ -701,10 +700,23 @@ def get_quantize_config_class(
def get_quantize_config_with_recipe(fp8_recipe: Recipe): def get_quantize_config_with_recipe(fp8_recipe: Recipe):
"""Get the quantization configuration object based on the FP8 recipe.""" """Get the quantization configuration object based on the FP8 recipe."""
config = get_quantize_config_class(fp8_recipe)() config = get_quantize_config_class(fp8_recipe)()
if fp8_recipe is not None:
config.initialize_from_recipe(fp8_recipe) config.initialize_from_recipe(fp8_recipe)
return config return config
_GLOBAL_RECIPE: Optional[Recipe] = None
def get_global_quantize_recipe() -> Optional[Recipe]:
    """Return the quantization recipe stored in the module-level global.

    The value is read from ``_GLOBAL_RECIPE`` at call time, so the result
    reflects whatever is currently assigned to that global.

    Returns:
        The global quantization recipe, or ``None`` if no recipe is set.
    """
    active_recipe = _GLOBAL_RECIPE
    return active_recipe
@contextmanager @contextmanager
def autocast( def autocast(
enabled: bool = False, enabled: bool = False,
...@@ -751,22 +763,21 @@ def autocast( ...@@ -751,22 +763,21 @@ def autocast(
if recipe is None: if recipe is None:
recipe = DelayedScaling() recipe = DelayedScaling()
global _QUANTIZE_CONFIG global _GLOBAL_RECIPE
old_quantize_config = _QUANTIZE_CONFIG old_global_recipe = _GLOBAL_RECIPE
_QUANTIZE_CONFIG = NoOpQuantizeConfig() _GLOBAL_RECIPE = None
try: try:
with global_shard_guard(mesh_resource): with global_shard_guard(mesh_resource):
if enabled: if enabled:
_QUANTIZE_CONFIG = get_quantize_config_class(recipe)() _GLOBAL_RECIPE = recipe
is_supported, reason = _QUANTIZE_CONFIG.is_supported() is_supported, reason = get_quantize_config_class(_GLOBAL_RECIPE)().is_supported()
assert is_supported, reason assert is_supported, reason
_QUANTIZE_CONFIG.initialize_from_recipe(recipe)
yield yield
finally: finally:
_QUANTIZE_CONFIG = old_quantize_config _GLOBAL_RECIPE = old_global_recipe
@contextmanager @contextmanager
......
...@@ -28,7 +28,7 @@ from .tensor import ( ...@@ -28,7 +28,7 @@ from .tensor import (
NoScaleTensor, NoScaleTensor,
) )
from .helper import ( from .helper import (
get_quantize_config, get_global_quantize_recipe,
get_quantize_config_with_recipe, get_quantize_config_with_recipe,
AmaxComputeAlgo, AmaxComputeAlgo,
TensorSource, TensorSource,
...@@ -50,7 +50,7 @@ __all__ = [ ...@@ -50,7 +50,7 @@ __all__ = [
def compute_scale_from_amax( def compute_scale_from_amax(
amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None amax: jnp.ndarray, q_dtype: jnp.dtype, margin: float, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Compute scale from amax value. """Compute scale from amax value.
...@@ -64,7 +64,7 @@ def compute_scale_from_amax( ...@@ -64,7 +64,7 @@ def compute_scale_from_amax(
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None: if scale is None:
scale = jnp.ones((1,)) scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = (fp8_max / amax) / (2**margin)
sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
...@@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer): ...@@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to CURRENT_TENSOR_SCALING scaling_mode: Set to CURRENT_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
""" """
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
...@@ -254,8 +255,7 @@ class CurrentScaleQuantizer(Quantizer): ...@@ -254,8 +255,7 @@ class CurrentScaleQuantizer(Quantizer):
compute_dtype = jnp.float32 compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) scale = compute_scale_from_amax(amax, self.q_dtype, margin=0.0)
scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
scaled_x = x.data.astype(compute_dtype) * scale scaled_x = x.data.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
...@@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
data_layout: Data layout string (default: "NT")
margin: Margin value for scale computation
amax_compute_algo: Algorithm for computing amax
scale: Current scaling factor scale: Current scaling factor
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
""" """
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING margin: float = 0.0
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field( amax_history: jnp.ndarray = field(default_factory=lambda: jnp.zeros((1024,), jnp.float32))
default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
) def __post_init__(self):
assert self.margin is not None, "margin must be specified"
assert self.amax_compute_algo is not None, "amax_compute_algo must be specified"
assert self.amax_history is not None, "amax_history must be specified"
def tree_flatten(self): def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations. """Flatten the quantizer for JAX tree operations.
...@@ -352,6 +358,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -352,6 +358,8 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
self.q_layout, self.q_layout,
self.data_layout, self.data_layout,
self.checkpoint_name, self.checkpoint_name,
self.margin,
self.amax_compute_algo,
) )
return (children, aux_data) return (children, aux_data)
...@@ -407,12 +415,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -407,12 +415,14 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Returns: Returns:
Updated AMAX history Updated AMAX history
""" """
amax_history = amax_history.at[0].set(new_amax[0]) amax_history = amax_history.at[0].set(new_amax.reshape((1,))[0])
return amax_history return amax_history
@staticmethod @staticmethod
@partial(jax.jit, static_argnums=(2,)) @partial(jax.jit, static_argnums=(2, 3, 4))
def _compute_scale(amax_history, scale, q_dtype): def _compute_scale(
amax_history, scale, q_dtype, amax_compute_algo: AmaxComputeAlgo, margin: float
):
"""Compute new scale based on AMAX history. """Compute new scale based on AMAX history.
Args: Args:
...@@ -424,12 +434,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -424,12 +434,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value Updated scale value
""" """
# 2. Calculate the current scale # 2. Calculate the current scale
if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: if amax_compute_algo is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True) amax = jnp.max(amax_history, axis=-1, keepdims=True)
else: else:
amax = amax_history[0:1] amax = amax_history[0:1]
return compute_scale_from_amax(amax, q_dtype, scale=scale) return compute_scale_from_amax(amax, q_dtype, margin=margin, scale=scale)
@staticmethod @staticmethod
@jax.jit @jax.jit
...@@ -453,7 +463,9 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -453,7 +463,9 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
new_amax: New maximum absolute value to add to history new_amax: New maximum absolute value to add to history
""" """
amax_history = self._update_amax_history(self.amax_history, new_amax) amax_history = self._update_amax_history(self.amax_history, new_amax)
self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype) self.scale = self._compute_scale(
amax_history, self.scale, self.q_dtype, self.amax_compute_algo, self.margin
)
self.amax_history = self._roll_and_reset_amax_history(amax_history) self.amax_history = self._roll_and_reset_amax_history(amax_history)
...@@ -1124,6 +1136,7 @@ class QuantizerFactory: ...@@ -1124,6 +1136,7 @@ class QuantizerFactory:
bwd_dtype, bwd_dtype,
is_2x2x, is_2x2x,
n_groups, n_groups,
is_inference_mode=False,
checkpoint_name: Optional[str] = None, checkpoint_name: Optional[str] = None,
**kwargs, **kwargs,
) -> QuantizerSet: ) -> QuantizerSet:
...@@ -1137,6 +1150,7 @@ class QuantizerFactory: ...@@ -1137,6 +1150,7 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization is_2x2x: Whether to use 2x2x quantization
n_groups n_groups
is_inference_mode: Whether to create quantizers for inference mode. This option is not fully supported yet
checkpoint_name: Optional name for checkpointing quantizations checkpoint_name: Optional name for checkpointing quantizations
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
...@@ -1149,7 +1163,7 @@ class QuantizerFactory: ...@@ -1149,7 +1163,7 @@ class QuantizerFactory:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if kernel_scaling_mode.is_1d_block_scaling(): if kernel_scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
if get_quantize_config().INFERENCE_MODE: if is_inference_mode:
q_layout_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
...@@ -1206,10 +1220,10 @@ class QuantizerFactory: ...@@ -1206,10 +1220,10 @@ class QuantizerFactory:
Args: Args:
n_quantizer_sets: Number of quantizer sets to create n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode scaling_mode: Scaling mode to use; defaults to the scaling mode of the specified or global recipe
fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE fwd_dtype: Data type for forward pass; defaults to the fwd dtype of the specified or global recipe
bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE bwd_dtype: Data type for backward pass; defaults to the bwd dtype of the specified or global recipe
is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X is_2x2x: Whether to use 2x2x quantization, default is determined based on the specified or global recipe
n_groups: n_groups:
checkpoint_name: Optional name for checkpointing quantizations checkpoint_name: Optional name for checkpointing quantizations
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
...@@ -1226,25 +1240,46 @@ class QuantizerFactory: ...@@ -1226,25 +1240,46 @@ class QuantizerFactory:
" scaling mode differs between x, kernel, and grad in the quantizer set." " scaling mode differs between x, kernel, and grad in the quantizer set."
) )
# TODO(jberchtold): Currently this is a limitation because we only support automatically populating quantizer fields based on a given recipe when using Flax. In the generic quantizer logic, we cannot assume Flax is being used, so we require the user to provide the quantize_meta_set created by quantize_config.get_quantize_flax_meta() or the same data created by themselves if they are passing a recipe here directly.
assert (
fp8_recipe is None or "quantize_meta_set" in kwargs
), "When fp8_recipe is specified, quantize_meta_set must be provided in kwargs."
if fp8_recipe is None:
fp8_recipe = get_global_quantize_recipe()
if fp8_recipe is not None: if fp8_recipe is not None:
assert scaling_mode is None, (
"scaling_mode should not be specified when fp8_recipe is provided either directly"
" or through an autocast context."
)
assert fwd_dtype is None, (
"fwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
assert bwd_dtype is None, (
"bwd_dtype should not be specified when fp8_recipe is provided either directly or"
" through an autocast context."
)
quantize_config = get_quantize_config_with_recipe(fp8_recipe) quantize_config = get_quantize_config_with_recipe(fp8_recipe)
x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
fwd_dtype = quantize_config.FWD_DTYPE fwd_dtype = quantize_config.FWD_DTYPE
bwd_dtype = quantize_config.BWD_DTYPE bwd_dtype = quantize_config.BWD_DTYPE
is_inference_mode = quantize_config.INFERENCE_MODE
else: else:
if scaling_mode is not None: if scaling_mode is not None:
x_scaling_mode = scaling_mode x_scaling_mode = scaling_mode
kernel_scaling_mode = scaling_mode kernel_scaling_mode = scaling_mode
grad_scaling_mode = scaling_mode grad_scaling_mode = scaling_mode
else: else:
x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) # TODO(jberchtold): Add a way to explicitly pass a no-scaling recipe here if we need other quantization config attributes in the future. NoOpQuantizeConfig already exists, but we cannot use it with direct recipe passing because fp8_recipe=None is ambiguous: it could mean "no recipe specified" or "explicitly no quantization desired".
kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) x_scaling_mode = ScalingMode.NO_SCALING
grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) kernel_scaling_mode = ScalingMode.NO_SCALING
grad_scaling_mode = ScalingMode.NO_SCALING
is_inference_mode = False
fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE
bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE
if is_2x2x is None: if is_2x2x is None:
# TODO(Jeremy): check x, kernel, grad separately for 2x # TODO(Jeremy): check x, kernel, grad separately for 2x
if x_scaling_mode.is_1d_block_scaling(): if x_scaling_mode.is_1d_block_scaling():
...@@ -1253,7 +1288,6 @@ class QuantizerFactory: ...@@ -1253,7 +1288,6 @@ class QuantizerFactory:
is_2x2x = not is_fp8_gemm_with_all_layouts_supported() is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False is_2x2x = False
is_inference_mode = get_quantize_config().INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!" assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = [] q_set = []
...@@ -1267,6 +1301,7 @@ class QuantizerFactory: ...@@ -1267,6 +1301,7 @@ class QuantizerFactory:
bwd_dtype=bwd_dtype, bwd_dtype=bwd_dtype,
is_2x2x=is_2x2x, is_2x2x=is_2x2x,
n_groups=n_groups, n_groups=n_groups,
is_inference_mode=is_inference_mode,
checkpoint_name=checkpoint_name, checkpoint_name=checkpoint_name,
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment