Unverified Commit 5d937c57 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

TE/JAX Enhancement (#135)



* Rename enable_fp8 to is_fp8_enabled.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding an API to get an instance of DelayedScaling which is set via fp8_autocast.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent a0f00654
......@@ -13,7 +13,7 @@ from jax.experimental import maps
from utils import assert_allclose, is_fp8_supported
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
......@@ -153,35 +153,39 @@ class TestFP8Helper(unittest.TestCase):
class TestFP8Functions(unittest.TestCase):
def _check_defult_state(self):
    """Assert the process-wide defaults: FP8 disabled and SINGLE sharding.

    NOTE(review): the name keeps its historical typo ("defult") because the
    other tests in this file call it under that name.
    """
    # Stale pre-rename call FP8Helper.enable_fp8() removed: the attribute was
    # renamed to is_fp8_enabled() in this commit and no longer exists.
    self.assertFalse(FP8Helper.is_fp8_enabled())
    self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)
def _compare_delay_scaling(self, ref, test):
    """Assert that the recipe fields stored via fp8_autocast match between
    two DelayedScaling instances.

    Uses assertEqual rather than assertTrue(a == b) so that a failure
    reports the two mismatching values instead of a bare "False is not true".
    """
    self.assertEqual(ref.margin, test.margin)
    self.assertEqual(ref.interval, test.interval)
    self.assertEqual(ref.fp8_format, test.fp8_format)
    self.assertEqual(ref.amax_history_len, test.amax_history_len)
    self.assertEqual(ref.amax_compute_algo, test.amax_compute_algo)
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_fp8_autocast(self):
    """fp8_autocast should expose the given recipe inside its context
    (via is_fp8_enabled / get_delayed_scaling) and restore the default
    state on exit.
    """
    FP8Helper.finalize()  # Ensure this test is not affected by previous tests.
    self._check_defult_state()

    # Disabled autocast: FP8 stays off and the default recipe is reported.
    with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
        self.assertFalse(FP8Helper.is_fp8_enabled())
        self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
    self._check_defult_state()

    # Enabled autocast: each recipe must be visible through
    # get_delayed_scaling() inside its own context, and the defaults must be
    # restored after leaving it. (Stale pre-rename assertions on
    # FP8Helper.enable_fp8() / FP8Helper.MARGIN etc. — leftovers of the diff —
    # are removed; the two identical sections are folded into one loop.)
    recipes = (
        DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1),
        DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1),
    )
    for ds in recipes:
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.is_fp8_enabled())
            self._compare_delay_scaling(get_delayed_scaling(), ds)
        self._check_defult_state()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
......@@ -210,11 +214,8 @@ class TestFP8Functions(unittest.TestCase):
with maps.Mesh(devices, ('dp', 'tp')):
for sr, mst in srs:
with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
self.assertTrue(FP8Helper.enable_fp8())
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
self.assertTrue(FP8Helper.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self.assertEqual(infer_major_sharding_type(), mst)
self._check_defult_state()
......@@ -200,7 +200,7 @@ class TestEncoderLayer:
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
for _ in range(4):
_, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
has_aux=False)(inputs, test_masks, test_params,
......@@ -411,7 +411,7 @@ class TestDecoderLayer:
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
for _ in range(4):
_, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
has_aux=False)(inputs, test_masks, test_params,
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .fp8 import fp8_autocast, update_collections, update_fp8_metas
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
......
......@@ -128,8 +128,8 @@ class FP8Helper:
FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_LEN: int = 1
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
NUM_META_PER_GEMM: int = 3
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
......@@ -144,7 +144,7 @@ class FP8Helper:
FP8_2X_ACC_WGRAD: bool = True
@staticmethod
def enable_fp8():
def is_fp8_enabled():
"""
Indicate whether fp8 training is enabled.
"""
......@@ -182,7 +182,8 @@ class FP8Helper:
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_LEN = 1
FP8Helper.AMAX_HISTORY_LEN = 1024
FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
@staticmethod
def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
......@@ -385,3 +386,27 @@ def update_fp8_metas(state: Collection) -> Collection:
The collection with updated FP8 metas.
"""
return FP8Helper.update_fp8_metas(state)
def get_delayed_scaling():
    r"""
    Obtain an instance of DelayedScaling which is set via fp8_autocast.

    .. note::
        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
        Other parameters in recipe.DelayedScaling would be returned as the default
        values.

    Returns
    -------
    delay_scaling : DelayedScaling
        an instance of DelayedScaling which is set via fp8_autocast.
    """
    # FP8Helper only records the enum; map it back to the string form that
    # DelayedScaling expects.
    if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
        algo_name = "max"
    else:
        algo_name = "most_recent"

    recipe_kwargs = {
        "margin": FP8Helper.MARGIN,
        "interval": FP8Helper.UPDATE_FP8META_INTERVAL,
        "fp8_format": FP8Helper.FP8_FORMAT,
        "amax_history_len": FP8Helper.AMAX_HISTORY_LEN,
        "amax_compute_algo": algo_name,
    }
    return DelayedScaling(**recipe_kwargs)
......@@ -412,7 +412,7 @@ class DenseGeneral(TransformerEngineBase):
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel])
y = fp8_dot(fp8_gemm_package,
......@@ -537,7 +537,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"""
ln_output = None
fuse_layernorm = FP8Helper.enable_fp8(
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm:
......@@ -583,7 +583,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel])
......@@ -751,7 +751,7 @@ class LayerNormMLP(TransformerEngineBase):
"""
ln_output = None
fuse_layernorm = FP8Helper.enable_fp8(
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
use_fused_ln_mlp = fuse_layernorm \
......@@ -840,7 +840,7 @@ class LayerNormMLP(TransformerEngineBase):
def fp8_meta_generator():
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None,
None)
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
TransformerEngineBase.get_fp8_metas(num_of_gemm)
return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
......@@ -867,7 +867,7 @@ class LayerNormMLP(TransformerEngineBase):
kernel = jnp.reshape(kernel, kernel_shape)
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = FP8GemmPackage(
1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :],
......@@ -936,7 +936,7 @@ class LayerNormMLP(TransformerEngineBase):
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
if FP8Helper.is_fp8_enabled():
fp8_gemm_package = FP8GemmPackage(
1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment