Unverified Commit 4ff3eed1 authored by jberchtold-nvidia, committed by GitHub

[JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#2348)



* Add test to check jaxpr that amax is reused for nvfp4 recipe
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move test to test_helper.py and rename file
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b14a3b62
@@ -45,7 +45,6 @@ from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
 from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.common import recipe
 
 GEMM_CASES = [
     (256, 256, 512),
...
@@ -11,7 +11,7 @@ import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from utils import assert_allclose
+from utils import assert_allclose, pytest_parametrize_wrapper
 
 from transformer_engine.common.recipe import (
     DelayedScaling,
     MXFP8BlockScaling,
@@ -22,6 +22,7 @@ from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import autocast
 from transformer_engine.jax.quantize import (
     get_quantize_config,
+    get_supported_quantization_recipes,
     is_scaling_mode_supported,
     ScalingMode,
     update_collections,
@@ -32,11 +33,15 @@ from transformer_engine.jax.quantize import (
 from transformer_engine.jax.quantize.helper import _format2dtypes
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 from transformer_engine.jax.flax.module import TransformerEngineBase
+from transformer_engine.jax import flax as te_flax
+import transformer_engine.jax as te
 
 is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
 is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
 is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
+
+SUPPORTED_RECIPES = get_supported_quantization_recipes()
 
 
 def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
     """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
@@ -253,3 +258,63 @@ class TestFP8Functions(unittest.TestCase):
         self._compare_nvfp4_scaling_quantizers(bs)
         self._check_default_state()
+
+
+class TestJaxprAndHlo:
+    """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
+
+    @pytest_parametrize_wrapper(
+        "quantization_recipe",
+        [
+            quantization_recipe
+            for quantization_recipe in SUPPORTED_RECIPES
+            if isinstance(quantization_recipe, NVFP4BlockScaling)
+        ],
+    )
+    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
+        """Tests that layernorm_mlp reuses the amax computed in the layernorm and the activation and does not recompute it during quantization."""
+        with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
+            model = te_flax.LayerNormMLP(
+                layernorm_type="rmsnorm",
+                return_layernorm_output=False,
+                intermediate_dropout_rate=0.0,
+                dtype=jnp.bfloat16,
+            )
+            var_collect = model.init(
+                jax.random.PRNGKey(0),
+                jnp.ones((128, 128), dtype=jnp.bfloat16),
+            )
+
+            def loss_fn(x, rngs):
+                return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0])
+
+            x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
+            rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
+
+            jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
+
+            rht_amax_eqns = [
+                eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
+            ]
+            assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
+
+            def assert_param(index, tensor_name, expected_value: bool):
+                if expected_value:
+                    assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
+                        f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
+                        " reuse of amax as this tensor does not have a previous operation to fuse"
+                        " with"
+                    )
+                else:
+                    assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
+                        f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
+                        " reuse of amax"
+                    )
+
+            assert_param(0, "fwd ln+q", False)
+            assert_param(1, "fwd act+q", False)
+            # No previous op before the incoming dgrad in the backward pass, so amax is not reused
+            assert_param(2, "bwd dgrad", True)
+            assert_param(3, "bwd dact+q", False)