Unverified Commit 4ff3eed1 authored by jberchtold-nvidia, committed by GitHub

[JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#2348)



* Add test to check jaxpr that amax is reused for nvfp4 recipe
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Move test to test_helper.py and rename file
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b14a3b62
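
For readers unfamiliar with the pattern, the new test works by tracing the model into a jaxpr and asserting on the primitive equations it contains. Here is a minimal standalone sketch of that pattern; the generic `reduce_max` primitive is used as a stand-in for Transformer Engine's `te_rht_amax_ffi_wrapper`:

```python
# Minimal sketch of the jaxpr-inspection pattern used by the test in this
# commit. `reduce_max` stands in for TE's `te_rht_amax_ffi_wrapper` primitive.
import jax
import jax.numpy as jnp

def f(x):
    amax = jnp.max(jnp.abs(x))  # traces to `abs` and `reduce_max` primitives
    return x / amax

jaxpr = jax.make_jaxpr(f)(jnp.ones((4, 4)))
max_eqns = [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "reduce_max"]
assert len(max_eqns) == 1, f"Expected 1 reduce_max eqn, got {len(max_eqns)}"
# Each equation also exposes `eqn.params`, which the test below asserts on
# (e.g. the `produce_regular_amax` flag).
```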
@@ -45,7 +45,6 @@ from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
 from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.common import recipe

 GEMM_CASES = [
     (256, 256, 512),
@@ -11,7 +11,7 @@ import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn

-from utils import assert_allclose
+from utils import assert_allclose, pytest_parametrize_wrapper
 from transformer_engine.common.recipe import (
     DelayedScaling,
     MXFP8BlockScaling,
@@ -22,6 +22,7 @@ from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import autocast
 from transformer_engine.jax.quantize import (
     get_quantize_config,
+    get_supported_quantization_recipes,
     is_scaling_mode_supported,
     ScalingMode,
     update_collections,
@@ -32,11 +33,15 @@ from transformer_engine.jax.quantize import (
 from transformer_engine.jax.quantize.helper import _format2dtypes
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 from transformer_engine.jax.flax.module import TransformerEngineBase
+from transformer_engine.jax import flax as te_flax
+import transformer_engine.jax as te

 is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
 is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
+is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
+SUPPORTED_RECIPES = get_supported_quantization_recipes()


 def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
     """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
@@ -253,3 +258,63 @@ class TestFP8Functions(unittest.TestCase):
             self._compare_nvfp4_scaling_quantizers(bs)
             self._check_default_state()
+
+
+class TestJaxprAndHlo:
+    """Tests to verify the Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
+
+    @pytest_parametrize_wrapper(
+        "quantization_recipe",
+        [
+            quantization_recipe
+            for quantization_recipe in SUPPORTED_RECIPES
+            if isinstance(quantization_recipe, NVFP4BlockScaling)
+        ],
+    )
+    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
+        """Tests that layernorm_mlp reuses the amax computed in the layernorm and the activation and does not recompute it during quantization."""
+        with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
+            model = te_flax.LayerNormMLP(
+                layernorm_type="rmsnorm",
+                return_layernorm_output=False,
+                intermediate_dropout_rate=0.0,
+                dtype=jnp.bfloat16,
+            )
+            var_collect = model.init(
+                jax.random.PRNGKey(0),
+                jnp.ones((128, 128), dtype=jnp.bfloat16),
+            )
+
+            def loss_fn(x, rngs):
+                return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0])
+
+            x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
+            rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
+
+            jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
+            rht_amax_eqns = [
+                eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
+            ]
+            assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
+
+            def assert_param(index, tensor_name, expected_value: bool):
+                if expected_value:
+                    assert rht_amax_eqns[index].params["produce_regular_amax"], (
+                        f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
+                        " reuse of amax, as this tensor does not have a previous operation to fuse"
+                        " with"
+                    )
+                else:
+                    assert not rht_amax_eqns[index].params["produce_regular_amax"], (
+                        f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
+                        " reuse of amax"
+                    )
+
+            assert_param(0, "fwd ln+q", False)
+            assert_param(1, "fwd act+q", False)
+            # No previous op before the incoming dgrad in the backward pass, so amax is not reused
+            assert_param(2, "bwd dgrad", True)
+            assert_param(3, "bwd dact+q", False)
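
For context on what `produce_regular_amax=False` implies: the producing op (layernorm or activation) already computes the tensor's amax, so the NVFP4 quantization step can consume it rather than re-reducing the tensor. Below is a toy sketch of that reuse, not TE's implementation; the 6.0 bound is FP4 (E2M1)'s maximum representable magnitude, and the rounding is only a stand-in for real NVFP4 block quantization:

```python
# Toy illustration of amax reuse (not TE internals): the producing op emits
# amax alongside its output, and quantization consumes it instead of
# re-scanning the tensor.
import jax
import jax.numpy as jnp

def rmsnorm_with_amax(x, gamma, eps=1e-6):
    y = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * gamma
    return y, jnp.max(jnp.abs(y))  # amax fused into the producer

def fake_quantize(y, amax, qmax=6.0):  # 6.0 = max |value| of FP4 (E2M1)
    scale = amax / qmax
    return jnp.round(y / scale) * scale  # no second reduction over y for amax
```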