Unverified Commit a7bc7cf7 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Support arbitrary dimensions of fp8 meta. (#309)


Signed-off-by: Ming Huang <mingh@nvidia.com>
parent a7a1a070
......@@ -64,7 +64,7 @@ class TestFP8Helper(unittest.TestCase):
def select_amax(amaxes):
if FP8Helper.AMAX_COMPUTE_ALGO == AmaxComputeAlgo.MAX:
return jnp.max(amaxes, axis=1, keepdims=True)
return jnp.max(amaxes, axis=-1, keepdims=True)
return amaxes[:, 0:1]
def get_fp8_scale(fp8_max, amax, scale):
......@@ -78,15 +78,16 @@ class TestFP8Helper(unittest.TestCase):
sf = np.where(np.isfinite(amax), sf, scale)
return np.where(exp < 0, 1 / sf, sf)
meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
scale_meta_shape = (num_of_meta, 1)
fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
fp8_amax_array1 = jax.random.uniform(key1, shape=amax_meta_shape)
fp8_scale_array1 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array1),
jnp.ones(meta_shape))
jnp.ones(scale_meta_shape))
fp8_scale_inv_array1 = 1 / fp8_scale_array1
fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
fp8_amax_array2 = jax.random.uniform(key2, shape=amax_meta_shape)
fp8_scale_array2 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array2),
jnp.ones(meta_shape))
jnp.ones(scale_meta_shape))
fp8_scale_inv_array2 = 1 / fp8_scale_array2
state = flax.core.frozen_dict.FrozenDict({
......@@ -94,14 +95,14 @@ class TestFP8Helper(unittest.TestCase):
"test_update_fp8_metas1": {
FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
},
"test_update_fp8_metas2": {
FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
}
}
})
......
......@@ -305,9 +305,9 @@ class FP8Helper:
fp8_max = fp8_meta_arrays[fp8_max_idx]
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=1, keepdims=True)
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
else:
amax = fp8_meta_arrays[fp8_amax_idx][:, 0:1]
amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
scale = fp8_meta_arrays[fp8_scale_idx]
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
......@@ -366,14 +366,14 @@ def fp8_autocast(enabled: bool = False,
if fp8_recipe is None:
fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_compute_algo in ["max", "most_recent"], (
"DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
assert fp8_recipe.amax_compute_algo in [
"max", "most_recent"
], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
assert fp8_recipe.scaling_factor_compute_algo is None, (
"DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
assert fp8_recipe.override_linear_precision == (False, False, False), (
"DelayedScaling override_linear_precision isn't supported by TE/JAX.")
assert fp8_recipe.reduce_amax, (
"DelayedScaling reduce_amax should be enabled for TE/JAX.")
assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")
if sharding_resource is None:
sharding_resource = ShardingResource()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment