Unverified Commit 2da34d41 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Change scaling factor from E8M0 to E8M23 (#427)



* Change scaling factor from E8M0 to E8M23
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix formula
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 479dbb73
......@@ -880,11 +880,9 @@ def test_amax_and_scale_update():
def calc_ref(amax, scale, fp8_max, margin=0):
    """Calculate the reference scaling factor (E8M23 scheme).

    sf = (fp8_max / amax) / 2**margin, keeping the previous ``scale``
    whenever ``amax`` is non-positive or non-finite.

    Args:
        amax: Tensor of absolute-maximum values observed per tensor.
        scale: Tensor of current scaling factors (fallback values).
        fp8_max: Maximum representable value of the FP8 format.
        margin: Extra power-of-two headroom subtracted from the scale.

    Returns:
        Tensor of updated scaling factors.
    """
    # Direct E8M23 scale: no longer rounded to a power of two, and no
    # exp-based inversion — the old round(2**|exp|) / (exp < 0) logic
    # belonged to the E8M0 scheme this commit removes.
    sf = (fp8_max / amax) / (2 ** margin)
    # Fall back to the previous scale when amax is unusable (<= 0 or inf/NaN).
    sf = paddle.where(amax > 0.0, sf, scale)
    sf = paddle.where(paddle.isfinite(amax), sf, scale)
    return sf
scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
......
......@@ -115,8 +115,7 @@ class DelayedScaling:
.. code-block:: python
FP8_MAX = maximum_representable_value(fp8_format)
new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)
* The scaling factor should always be a power of 2 to not introduce numerical
error during the conversion from FP8 to higher precision format.
......
......@@ -310,11 +310,9 @@ class FP8Helper:
amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
scale = fp8_meta_arrays[fp8_scale_idx]
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
sf = jnp.round(jnp.power(2, jnp.abs(exp)))
sf = (fp8_max / amax) / (2 ** FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = jnp.where(exp < 0, 1 / sf, sf)
fp8_meta_arrays[fp8_scale_idx] = scale
fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale
......@@ -426,11 +424,9 @@ def update_fp8_metas(state: Collection) -> Collection:
.. code-block:: python
sf = (fp8_max / amax) / (2 ^ margin)
sf = sf if amax > 0.0 else original_scale
updated_scale = sf if isfinite(amax) else original_scale
updated_scale_inv = 1/updated_scale
Collection = [dict, flax.core.frozen_dict.FrozenDict]
......
......@@ -1032,11 +1032,8 @@ __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_
amax_history[idx] = rolled_amax_history[idx];
if (idx < amax_numel) {
float exp = floor(log2(fp8_max / amax[idx])) - margin;
float sf = round(powf(2.0f, abs(exp)));
float scale_reg = scale[idx];
sf = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale_reg;
scale_reg = exp < 0.0f ? 1 / sf : sf;
float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
scale[idx] = scale_reg;
scale_inv[idx] = 1.0f / scale_reg;
amax_history[idx] = 0.0f;
......
......@@ -538,12 +538,9 @@ def _default_sf_compute(
margin: int,
) -> torch.Tensor:
"""Default function to convert amax to scaling factor."""
exp = torch.floor(torch.log2(fp8_max / amax)) - margin
sf = torch.round(torch.pow(2, torch.abs(exp)))
sf = (fp8_max / amax) / (2 ** margin)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(exp < 0, 1 / sf, sf)
return sf
......
......@@ -157,11 +157,9 @@ def get_fp8_recipe():
def _default_sf_compute(amax, scale, fp8_max, margin):
    """Default function to convert amax to scaling factor (E8M23 scheme).

    sf = (fp8_max / amax) / 2**margin, keeping the previous ``scale``
    whenever ``amax`` is non-positive or non-finite.

    Args:
        amax: Tensor of absolute-maximum values observed per tensor.
        scale: Tensor of current scaling factors (fallback values).
        fp8_max: Maximum representable value of the FP8 format.
        margin: Extra power-of-two headroom subtracted from the scale.

    Returns:
        Tensor of updated scaling factors.
    """
    # Direct E8M23 scale: the old round(pow(2, |exp|)) value and the
    # (exp < 0) inversion were part of the removed E8M0 scheme and must
    # not survive here — they would overwrite/contradict this formula.
    sf = (fp8_max / amax) / (2 ** margin)
    # Fall back to the previous scale when amax is unusable (<= 0 or inf/NaN).
    sf = tf.where(amax > 0.0, sf, scale)
    sf = tf.where(tf.math.is_finite(amax), sf, scale)
    return sf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment