"docs/vscode:/vscode.git/clone" did not exist on "e2fb71ec9f2c3168ba8614408fa807a5f65707c5"
Unverified Commit 76e1af33 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Add assertion message to amax -> scale computation (#2263)



assertion check
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 8c364b4d
...@@ -67,7 +67,7 @@ def compute_scale_from_amax( ...@@ -67,7 +67,7 @@ def compute_scale_from_amax(
sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale)
assert sf.shape == (1,) assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
return sf return sf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment