Unverified commit 757fd1cf, authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Fix Flax variable creation when quantizers are created directly from a recipe (#2079)



Fix Flax variable creation when quantizers are created directly from a recipe.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 6ba98d43
......@@ -15,6 +15,8 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from transformer_engine.common import recipe
from ..dense import dense
from ..layernorm import canonicalize_norm_type
......@@ -366,7 +368,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
fp8_recipe, recipe.DelayedScaling
):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment