Unverified Commit 5d937c57 authored by Ming-Xu Huang, committed by GitHub

TE/JAX Enhancement (#135)

* Rename enable_fp8 to is_fp8_enabled.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add an API, get_delayed_scaling, to retrieve the DelayedScaling instance set via fp8_autocast.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent a0f00654
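Taken together, the two changes let user code query the FP8 state and the active recipe through the public JAX API instead of reading FP8Helper class attributes. Below is a minimal usage sketch based on the test changes in this commit; it is illustrative only, the DelayedScaling arguments are arbitrary example values, and an FP8-capable GPU is assumed.

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper

# Example recipe; the values here are arbitrary, not library defaults.
ds = DelayedScaling(margin=2.0, interval=1, fp8_format=Format.HYBRID, amax_history_len=16)

with fp8_autocast(enabled=True, fp8_recipe=ds):
    assert FP8Helper.is_fp8_enabled()      # renamed from FP8Helper.enable_fp8()
    recipe = get_delayed_scaling()         # reflects the recipe passed to fp8_autocast
    assert recipe.margin == ds.margin

assert not FP8Helper.is_fp8_enabled()      # default state is restored on exit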
@@ -13,7 +13,7 @@ from jax.experimental import maps
 from utils import assert_allclose, is_fp8_supported
 from transformer_engine.common.recipe import DelayedScaling
 from transformer_engine.common.recipe import Format as FP8Format
-from transformer_engine.jax import fp8_autocast
+from transformer_engine.jax import fp8_autocast, get_delayed_scaling
 from transformer_engine.jax.fp8 import FP8Helper
 from transformer_engine.jax.sharding import infer_major_sharding_type
 from transformer_engine.jax.sharding import MajorShardingType
@@ -153,35 +153,39 @@ class TestFP8Helper(unittest.TestCase):
 class TestFP8Functions(unittest.TestCase):

     def _check_defult_state(self):
-        self.assertFalse(FP8Helper.enable_fp8())
+        self.assertFalse(FP8Helper.is_fp8_enabled())
         self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)

+    def _compare_delay_scaling(self, ref, test):
+        self.assertTrue(ref.margin == test.margin)
+        self.assertTrue(ref.interval == test.interval)
+        self.assertTrue(ref.fp8_format == test.fp8_format)
+        self.assertTrue(ref.amax_history_len == test.amax_history_len)
+        self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
+
     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_fp8_autocast(self):
         FP8Helper.finalize()    # Ensure the testing not affect by previous tests.

         self._check_defult_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
-            self.assertFalse(FP8Helper.enable_fp8())
+            self.assertFalse(FP8Helper.is_fp8_enabled())
+            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
-            self.assertTrue(FP8Helper.enable_fp8())
-            self.assertEqual(FP8Helper.MARGIN, ds.margin)
-            self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
-            self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
+            self.assertTrue(FP8Helper.is_fp8_enabled())
+            self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

         ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
-            self.assertTrue(FP8Helper.enable_fp8())
-            self.assertEqual(FP8Helper.MARGIN, ds.margin)
-            self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
-            self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
+            self.assertTrue(FP8Helper.is_fp8_enabled())
+            self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@@ -210,11 +214,8 @@ class TestFP8Functions(unittest.TestCase):
         with maps.Mesh(devices, ('dp', 'tp')):
             for sr, mst in srs:
                 with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
-                    self.assertTrue(FP8Helper.enable_fp8())
-                    self.assertEqual(FP8Helper.MARGIN, ds.margin)
-                    self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
-                    self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-                    self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
+                    self.assertTrue(FP8Helper.is_fp8_enabled())
+                    self._compare_delay_scaling(get_delayed_scaling(), ds)
                     self.assertEqual(infer_major_sharding_type(), mst)

         self._check_defult_state()
@@ -200,7 +200,7 @@ class TestEncoderLayer:
         ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)

-        if FP8Helper.enable_fp8():
+        if FP8Helper.is_fp8_enabled():
             for _ in range(4):
                 _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                  has_aux=False)(inputs, test_masks, test_params,
@@ -411,7 +411,7 @@ class TestDecoderLayer:
         ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)

-        if FP8Helper.enable_fp8():
+        if FP8Helper.is_fp8_enabled():
             for _ in range(4):
                 _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                  has_aux=False)(inputs, test_masks, test_params,
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
 """Transformer Engine bindings for JAX"""
-from .fp8 import fp8_autocast, update_collections, update_fp8_metas
+from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
 from .module import DenseGeneral, LayerNorm
 from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
 from .transformer import extend_logical_axis_rules
@@ -128,8 +128,8 @@ class FP8Helper:
     FWD_DTYPE: DType = DType.kFloat8E4M3
     BWD_DTYPE: DType = DType.kFloat8E5M2
     UPDATE_FP8META_INTERVAL: int = 1
-    AMAX_HISTORY_LEN: int = 1
-    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT
+    AMAX_HISTORY_LEN: int = 1024
+    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
     NUM_META_PER_GEMM: int = 3
     INPUT_META_IDX_PER_GEMM: int = 0
     KERNEL_META_IDX_PER_GEMM: int = 1
@@ -144,7 +144,7 @@ class FP8Helper:
     FP8_2X_ACC_WGRAD: bool = True

     @staticmethod
-    def enable_fp8():
+    def is_fp8_enabled():
         """
         Indicate if fp8 training is enable or not.
         """
@@ -182,7 +182,8 @@ class FP8Helper:
         FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
         FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
         FP8Helper.UPDATE_FP8META_INTERVAL = 1
-        FP8Helper.AMAX_HISTORY_LEN = 1
+        FP8Helper.AMAX_HISTORY_LEN = 1024
+        FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX

     @staticmethod
     def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
@@ -385,3 +386,27 @@ def update_fp8_metas(state: Collection) -> Collection:
         The collection with updated FP8 metas.
     """
     return FP8Helper.update_fp8_metas(state)
+
+
+def get_delayed_scaling():
+    r"""
+    Obtain an instance of DelayedScaling which is set via fp8_autocast.
+
+    .. note::
+        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`interval`,
+        :attr:`amax_history_len` and :attr:`amax_compute_algo` via fp8_autocast.
+        Other parameters in recipe.DelayedScaling would be returned as the default
+        values.
+
+    Returns
+    -------
+    delay_scaling : DelayedScaling
+        an instance of DelayedScaling which is set via fp8_autocast.
+    """
+    amax_compute_algo = "max" if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX \
+                        else "most_recent"
+    return DelayedScaling(margin=FP8Helper.MARGIN,
+                          interval=FP8Helper.UPDATE_FP8META_INTERVAL,
+                          fp8_format=FP8Helper.FP8_FORMAT,
+                          amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
+                          amax_compute_algo=amax_compute_algo)
@@ -412,7 +412,7 @@ class DenseGeneral(TransformerEngineBase):
         contract_ind = tuple(range(0, len(axis)))

-        if FP8Helper.enable_fp8():
+        if FP8Helper.is_fp8_enabled():
             fp8_gemm_package = \
                 TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel])
             y = fp8_dot(fp8_gemm_package,
@@ -537,7 +537,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         """
         ln_output = None

-        fuse_layernorm = FP8Helper.enable_fp8(
+        fuse_layernorm = FP8Helper.is_fp8_enabled(
         ) and not self.return_layernorm_output and self.enable_layernorm

         if self.enable_layernorm:
@@ -583,7 +583,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         contract_ind = tuple(range(0, len(axis)))

-        if FP8Helper.enable_fp8():
+        if FP8Helper.is_fp8_enabled():
             fp8_gemm_package = \
                 TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel])
@@ -751,7 +751,7 @@ class LayerNormMLP(TransformerEngineBase):
         """
         ln_output = None

-        fuse_layernorm = FP8Helper.enable_fp8(
+        fuse_layernorm = FP8Helper.is_fp8_enabled(
         ) and not self.return_layernorm_output and self.enable_layernorm

         use_fused_ln_mlp = fuse_layernorm \
@@ -840,7 +840,7 @@ class LayerNormMLP(TransformerEngineBase):
         def fp8_meta_generator():
             fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None,
                                                                              None)
-            if FP8Helper.enable_fp8():
+            if FP8Helper.is_fp8_enabled():
                 fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
                     TransformerEngineBase.get_fp8_metas(num_of_gemm)
             return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
@@ -867,7 +867,7 @@ class LayerNormMLP(TransformerEngineBase):
         kernel = jnp.reshape(kernel, kernel_shape)
         contract_ind = tuple(range(0, len(axis)))

-        if FP8Helper.enable_fp8():
+        if FP8Helper.is_fp8_enabled():
             fp8_gemm_package = FP8GemmPackage(
                 1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :],
                 fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :],
@@ -936,7 +936,7 @@ class LayerNormMLP(TransformerEngineBase):
         contract_ind = tuple(range(0, len(axis)))

-        if FP8Helper.enable_fp8():
+        if FP8Helper.is_fp8_enabled():
             fp8_gemm_package = FP8GemmPackage(
                 1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :],
                 fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :],