Unverified Commit a7bc7cf7 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Support arbitrary dimensions of fp8 meta. (#309)


Signed-off-by: Ming Huang <mingh@nvidia.com>
parent a7a1a070
...@@ -64,7 +64,7 @@ class TestFP8Helper(unittest.TestCase): ...@@ -64,7 +64,7 @@ class TestFP8Helper(unittest.TestCase):
def select_amax(amaxes): def select_amax(amaxes):
if FP8Helper.AMAX_COMPUTE_ALGO == AmaxComputeAlgo.MAX: if FP8Helper.AMAX_COMPUTE_ALGO == AmaxComputeAlgo.MAX:
return jnp.max(amaxes, axis=1, keepdims=True) return jnp.max(amaxes, axis=-1, keepdims=True)
return amaxes[:, 0:1] return amaxes[:, 0:1]
def get_fp8_scale(fp8_max, amax, scale): def get_fp8_scale(fp8_max, amax, scale):
...@@ -78,15 +78,16 @@ class TestFP8Helper(unittest.TestCase): ...@@ -78,15 +78,16 @@ class TestFP8Helper(unittest.TestCase):
sf = np.where(np.isfinite(amax), sf, scale) sf = np.where(np.isfinite(amax), sf, scale)
return np.where(exp < 0, 1 / sf, sf) return np.where(exp < 0, 1 / sf, sf)
meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN) amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
scale_meta_shape = (num_of_meta, 1)
fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta) fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape) fp8_amax_array1 = jax.random.uniform(key1, shape=amax_meta_shape)
fp8_scale_array1 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array1), fp8_scale_array1 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array1),
jnp.ones(meta_shape)) jnp.ones(scale_meta_shape))
fp8_scale_inv_array1 = 1 / fp8_scale_array1 fp8_scale_inv_array1 = 1 / fp8_scale_array1
fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape) fp8_amax_array2 = jax.random.uniform(key2, shape=amax_meta_shape)
fp8_scale_array2 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array2), fp8_scale_array2 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array2),
jnp.ones(meta_shape)) jnp.ones(scale_meta_shape))
fp8_scale_inv_array2 = 1 / fp8_scale_array2 fp8_scale_inv_array2 = 1 / fp8_scale_array2
state = flax.core.frozen_dict.FrozenDict({ state = flax.core.frozen_dict.FrozenDict({
...@@ -94,14 +95,14 @@ class TestFP8Helper(unittest.TestCase): ...@@ -94,14 +95,14 @@ class TestFP8Helper(unittest.TestCase):
"test_update_fp8_metas1": { "test_update_fp8_metas1": {
FP8Helper.FP8_MAX_NAME: fp8_max_array, FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array1, FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape), FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape) FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
}, },
"test_update_fp8_metas2": { "test_update_fp8_metas2": {
FP8Helper.FP8_MAX_NAME: fp8_max_array, FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array2, FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape), FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape) FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
} }
} }
}) })
......
...@@ -305,9 +305,9 @@ class FP8Helper: ...@@ -305,9 +305,9 @@ class FP8Helper:
fp8_max = fp8_meta_arrays[fp8_max_idx] fp8_max = fp8_meta_arrays[fp8_max_idx]
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=1, keepdims=True) amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
else: else:
amax = fp8_meta_arrays[fp8_amax_idx][:, 0:1] amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
scale = fp8_meta_arrays[fp8_scale_idx] scale = fp8_meta_arrays[fp8_scale_idx]
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
...@@ -366,14 +366,14 @@ def fp8_autocast(enabled: bool = False, ...@@ -366,14 +366,14 @@ def fp8_autocast(enabled: bool = False,
if fp8_recipe is None: if fp8_recipe is None:
fp8_recipe = DelayedScaling() fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_compute_algo in ["max", "most_recent"], ( assert fp8_recipe.amax_compute_algo in [
"DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.") "max", "most_recent"
], ("DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
assert fp8_recipe.scaling_factor_compute_algo is None, ( assert fp8_recipe.scaling_factor_compute_algo is None, (
"DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.") "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
assert fp8_recipe.override_linear_precision == (False, False, False), ( assert fp8_recipe.override_linear_precision == (False, False, False), (
"DelayedScaling override_linear_precision isn't supported by TE/JAX.") "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
assert fp8_recipe.reduce_amax, ( assert fp8_recipe.reduce_amax, ("DelayedScaling reduce_amax should be enabled for TE/JAX.")
"DelayedScaling reduce_amax should be enabled for TE/JAX.")
if sharding_resource is None: if sharding_resource is None:
sharding_resource = ShardingResource() sharding_resource = ShardingResource()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment