Unverified Commit bf3e1715 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Update reference scale calculation in TensorFlow test (#463)


Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 38b85c35
......@@ -34,11 +34,9 @@ def get_fp8_recipe(override_wgrad=False):
def compute_scale(amax, scale, fp8_max, margin):
"""Default function to convert amax to scaling factor."""
exp = tf.math.floor(tf.experimental.numpy.log2(fp8_max / amax)) - margin
sf = tf.math.round(tf.math.pow(2., tf.math.abs(exp)))
sf = (fp8_max / amax) / (2 ** margin)
sf = tf.where(amax > 0.0, sf, scale)
sf = tf.where(tf.math.is_finite(amax), sf, scale)
sf = tf.where(exp < 0, 1.0 / sf, sf)
return sf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment