"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "92f431bfee3bebbf49d1cc6f6bc37796bffd8bb7"
Unverified Commit dfeef1a2 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Address tolerance check for current scaling dact dbias (#2211)



Address tolerance check for current scaling dact
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent a91e4585
......@@ -780,9 +780,15 @@ class TestFusedQuantize:
assert_allclose(te_output.data, jax_output.data)
if is_dbias:
# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
precise_comparison = not (
in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
or (
activation_type == ("squared_relu",)
and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
)
)
assert_allclose(
te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment