Unverified Commit 6280dc7a authored by Frédéric Bastien, committed by GitHub

JAX small changes (#251)



* Use the same default in the function as the class default.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Assert instead of silently ignoring unsupported variations. Small doc correction: amax_compute_algo is partially supported.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Fix line length to fix the CI.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>

* grammar
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Clarify that it is only TE/JAX that doesn't support that feature.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>

* Update the test following the change in default value
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Fix CI.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 207b231e
......@@ -14,7 +14,7 @@ from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource
......@@ -63,6 +63,11 @@ class TestFP8Helper(unittest.TestCase):
num_of_gemm = 10
num_of_meta = FP8Helper.NUM_META_PER_GEMM * num_of_gemm
def select_amax(amaxes):
    if FP8Helper.AMAX_COMPUTE_ALGO == AmaxComputeAlgo.MAX:
        return jnp.max(amaxes, axis=1, keepdims=True)
    return amaxes[:, 0:1]
def get_fp8_scale(fp8_max, amax, scale):
    fp8_max = np.array(fp8_max)
    amax = np.array(amax)
......@@ -77,11 +82,11 @@ class TestFP8Helper(unittest.TestCase):
meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1[:, 0:1],
fp8_scale_array1 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array1),
jnp.ones(meta_shape))
fp8_scale_inv_array1 = 1 / fp8_scale_array1
fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2[:, 0:1],
fp8_scale_array2 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array2),
jnp.ones(meta_shape))
fp8_scale_inv_array2 = 1 / fp8_scale_array2
......
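For context (not part of the patch): a minimal sketch of the two amax-selection modes that the new select_amax helper mirrors, using a made-up history array of shape (num_of_meta, AMAX_HISTORY_LEN).

import jax.numpy as jnp

# Fabricated example history: 2 meta entries, history length 3.
amax_history = jnp.array([[0.5, 2.0, 1.0],
                          [4.0, 0.1, 0.2]])

# AmaxComputeAlgo.MAX: reduce over the history axis, keeping the column dim.
amax_max = jnp.max(amax_history, axis=1, keepdims=True)       # [[2.0], [4.0]]

# AmaxComputeAlgo.MOST_RECENT: take only the first (most recent) history entry.
amax_most_recent = amax_history[:, 0:1]                       # [[0.5], [4.0]]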
......@@ -192,7 +192,7 @@ class FP8Helper:
fp8_format: Format = Format.HYBRID,
update_fp8meta_interval: int = 1,
amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT) -> None:
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX) -> None:
"""
Initialize the FP8 meta
"""
......@@ -346,9 +346,11 @@ def fp8_autocast(enabled: bool = False,
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
:attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
in recipe.DelayedScaling would be ignored, even if set.
We only support :attr:`margin`, :attr:`fp8_format`,
:attr:`interval`, :attr:`amax_history_len` and
:attr:`amax_compute_algo` (with values 'max' and 'most_recent')
in recipe.DelayedScaling currently. Other parameters in
recipe.DelayedScaling will trigger an assertion.
Parameters
----------
......@@ -358,11 +360,21 @@ def fp8_autocast(enabled: bool = False,
Recipe used for FP8 training.
sharding_resource: ShardingResource, default = None
Specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then ShardingResource() would be created.
If set to None, then no data or tensor parallelism will be used.
"""
if fp8_recipe is None:
    fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_compute_algo in ["max", "most_recent"], (
    "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX.")
assert fp8_recipe.scaling_factor_compute_algo is None, (
    "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX.")
assert fp8_recipe.override_linear_precision == (False, False, False), (
    "DelayedScaling override_linear_precision isn't supported by TE/JAX.")
assert fp8_recipe.reduce_amax, (
    "DelayedScaling reduce_amax should be enabled for TE/JAX.")
if sharding_resource is None:
    sharding_resource = ShardingResource()
......
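For reference (not part of the patch), a hedged usage sketch of a recipe that satisfies the new assertions; the parameter names follow the recipe.DelayedScaling attributes referenced above, and the concrete values are illustrative only.

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast

recipe = DelayedScaling(
    margin=0,                      # illustrative value
    fp8_format=Format.HYBRID,
    amax_history_len=16,           # illustrative value
    amax_compute_algo="max",       # only 'max' or 'most_recent' pass on TE/JAX
)

with fp8_autocast(enabled=True, fp8_recipe=recipe):
    # Build and run the model under FP8 autocast here, e.g.
    # pjit(transformer.init, ...)(...) as in the docstring example above.
    pass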