Unverified commit bc9d57a3 authored by Jeng Bai-Cheng, committed by GitHub

Add TE/JAX high-level modules, unittests and examples (#54)



* add transformer module, unittests and examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update tests/jax/test_sharding.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update transformer_engine/jax/transformer.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint: disable=line-too-long
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove pylint: disable=too-many-func-args
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Fix the wrong broadcasting dim of dropout masks when transpose_bs is enabled.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Enable 2xACC for WGRAD and DGRAD by default
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename LayerNormMlpBlock to LayerNormMLP
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor to avoid line-too-long
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename amax_history_size to amax_history_len
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* align dropout mask with TE/PyTorch by default
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* enlarge atol for decoder unittests

Two decoder unittests pass in an older JAX container (e.g., 23.02)
but fail in the latest (devel) container.

1. The actual (-0.020264) and desired (-0.020386) values are very close;
   the absolute difference is only about 1.22e-4.
2. The TE kernels are unchanged, so the difference should come from
   new codegen behavior in XLA.

Thus, it is a common accumulated floating-point error.
Enlarge atol to avoid unittest failures.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Adding Amax History Support

1. hide amax update in custom_vjp
2. replace amax indexing with roll (using a circular buffer; see the sketch after the commit metadata below)
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* move kernel_init to __post_init__
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor encoder examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove envvar regarding 2xACC
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused import
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5925d444
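The circular-buffer amax update mentioned in the commit message works roughly as follows. This is a minimal sketch with illustrative names; the actual implementation is FP8Helper.update_amax_history in transformer_engine/jax/fp8.py, shown in the diffs further down:

import jax.numpy as jnp

def roll_amax_history(amax_history):
    # amax_history: [num_fp8_metas, AMAX_HISTORY_LEN]
    # Shift every row of the circular buffer left by one slot and clear
    # slot 0, which the next forward/backward pass overwrites with the
    # freshly computed amax values.
    rolled = jnp.roll(amax_history, shift=-1, axis=1)
    return rolled.at[:, 0].set(0.0)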
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder with BF16 Training on single GPU"""
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024
def network():
"""NLP Encoder"""
encoder = te.TransformerLayer(hidden_size=HIDDEN,
mlp_hidden_size=4 * HIDDEN,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_type='rmsnorm',
mlp_activations=('gelu', 'linear'),
layer_type=te.TransformerLayerType.ENCODER,
transpose_batch_sequence=True,
dtype=jnp.bfloat16)
return encoder
def synthesis_data(data_rng):
"""Dataset generator"""
return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)
def train_step(batch, state, others):
"""Training function."""
def loss_fn(collections):
logits = state.apply_fn(collections, batch)
loss = jnp.mean(logits)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
grads, params_grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=params_grads)
return loss, state, others
def test_encoder():
"""Encoder example"""
rng = jax.random.PRNGKey(0)
rng, init_rng, data_rng = jax.random.split(rng, 3)
inputs = synthesis_data(data_rng)
encoder = network()
variables = jax.jit(encoder.init)(init_rng, inputs)
variables, params = variables.pop(PARAMS_KEY)
optimizer = optax.sgd(0.001, 0.9)
state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)
jitted_train_step = jax.jit(train_step)
for i in range(5):
rng, data_rng = jax.random.split(rng)
inputs = synthesis_data(data_rng)
loss, state, variables = jitted_train_step(inputs, state, variables)
print(f"Step {i} - Loss: {loss}")
if __name__ == "__main__":
test_encoder()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder with FP8 Training on single GPU"""
import jax
import jax.numpy as jnp
import optax
from cuda import cudart
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.common.recipe import DelayedScaling
PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024
def gpu_has_fp8():
"""GPU arch has to support FP8"""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
def network():
"""NLP Encoder"""
encoder = te.TransformerLayer(hidden_size=HIDDEN,
mlp_hidden_size=4 * HIDDEN,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_type='rmsnorm',
mlp_activations=('gelu', 'linear'),
layer_type=te.TransformerLayerType.ENCODER,
transpose_batch_sequence=True,
dtype=jnp.bfloat16)
return encoder
def synthesis_data(data_rng):
"""Dataset generator"""
return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)
def train_step(batch, state, others):
"""Training function."""
def loss_fn(collections):
logits = state.apply_fn(collections, batch)
loss = jnp.mean(logits)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
grads, params_grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=params_grads)
others = FP8Helper.update_fp8_metas(grads)
return loss, state, others
def test_encoder():
"""Encoder example"""
if gpu_has_fp8() is False:
print("GPU doesn't support FP8")
return
rng = jax.random.PRNGKey(0)
rng, init_rng, data_rng = jax.random.split(rng, 3)
inputs = synthesis_data(data_rng)
optimizer = optax.sgd(0.001, 0.9)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=FP8Format.HYBRID)):
encoder = network()
variables = jax.jit(encoder.init)(init_rng, inputs)
variables, params = variables.pop(PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)
jitted_train_step = jax.jit(train_step)
assert "fp8" in str(jax.make_jaxpr(jitted_train_step)(inputs, state, variables))
for i in range(5):
rng, data_rng = jax.random.split(rng)
inputs = synthesis_data(data_rng)
loss, state, variables = jitted_train_step(inputs, state, variables)
print(f"Step {i} - Loss: {loss}")
if __name__ == "__main__":
test_encoder()
......@@ -6,3 +6,4 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax
pytest -Wignore -v $TE_PATH/examples/jax
......@@ -60,7 +60,7 @@ class TestFP8Dot:
def func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
......@@ -68,7 +68,7 @@ class TestFP8Dot:
# y = input, matrix 2d (weight)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))
value_n_grad_func = value_and_grad(func, (0, 1))
value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
......@@ -84,13 +84,13 @@ class TestFP8Dot:
def func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))
value_n_grad_func = value_and_grad(func, (0, 1))
value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
......@@ -104,13 +104,13 @@ class TestFP8Dot:
b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
primitive_out = fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None))
primitive_out = fp8_dot(fp8_gemm_pkg, *_format2dtypes(None))
ref_out = jnp.dot(a, b)
assert_allclose(primitive_out, ref_out)
......@@ -128,7 +128,7 @@ class TestFP8Dot:
b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
......@@ -136,12 +136,12 @@ class TestFP8Dot:
# calculate amax
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)
# calculate scale by amax
fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)
ref_out = jnp.dot(a, b)
ref_out = ref_out.astype(jnp.float32)
......@@ -158,13 +158,13 @@ class TestFP8Dot:
def primitive_func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
return jnp.mean(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))
def ref_func(x, y):
return jnp.mean(jnp.dot(x, y))
......@@ -193,7 +193,7 @@ class TestFP8Dot:
b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
......@@ -201,7 +201,7 @@ class TestFP8Dot:
def primitive_func(x, y, metas):
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *metas)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))
def ref_func(x, y):
return jnp.sum(jnp.dot(x, y))
......@@ -232,13 +232,13 @@ class TestFP8Dot:
def primitive_func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None), ((2, 3), (0, 1))))
return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None), ((2, 3), (0, 1))))
def ref_func(x, y):
return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))
......@@ -266,7 +266,7 @@ class TestFP8Dot:
s = jax.random.uniform(subkeys[3], (k,), jnp.bfloat16, 5, 8)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_SIZE),
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
......@@ -283,7 +283,6 @@ class TestFP8Dot:
ln_s,
None,
"rmsnorm",
0,
*compute_type,
activations=activations))
......@@ -305,7 +304,6 @@ class TestFP8Dot:
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
amax_history_idx: int,
fwd_dtype,
bwd_dtype,
epsilon=1e-6,
......@@ -323,7 +321,6 @@ class TestFP8Dot:
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = fp8_dot(fp8_gemm_1_pkg,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims,
......@@ -341,7 +338,6 @@ class TestFP8Dot:
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = fp8_dot(fp8_gemm_2_pkg,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims,
......@@ -350,7 +346,7 @@ class TestFP8Dot:
def ref_func(x, ln_s, y, z, metas):
return jnp.mean(
fp8_ln_mlp_py(x, ln_s, y, z, *metas, 0, *compute_type, activations=activations))
fp8_ln_mlp_py(x, ln_s, y, z, *metas, *compute_type, activations=activations))
value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3)))
......
......@@ -27,12 +27,12 @@ class TestFP8Helper(unittest.TestCase):
margin = 5.0
fp8_format = FP8Format.E4M3
update_fp8meta_interval = 10
amax_history_size = 10
amax_history_len = 10
FP8Helper.initialize(margin=margin,
fp8_format=fp8_format,
update_fp8meta_interval=update_fp8meta_interval,
amax_history_size=amax_history_size)
amax_history_len=amax_history_len)
self.assertEqual(
FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
......@@ -46,15 +46,15 @@ class TestFP8Helper(unittest.TestCase):
"FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
self.assertEqual(
FP8Helper.AMAX_HISTORY_SIZE, amax_history_size,
f"FP8Helper.AMAX_HISTORY_SIZE initialization failed, should be {amax_history_size}"
f" but got {FP8Helper.AMAX_HISTORY_SIZE}.")
FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {FP8Helper.AMAX_HISTORY_LEN}.")
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_update_fp8_metas(self):
FP8Helper.initialize(margin=3.0, amax_history_size=5)
FP8Helper.initialize(margin=3.0, amax_history_len=3)
seed = 0
key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
......@@ -72,13 +72,15 @@ class TestFP8Helper(unittest.TestCase):
sf = np.where(np.isfinite(amax), sf, scale)
return np.where(exp < 0, 1 / sf, sf)
meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_SIZE)
meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1, jnp.ones(meta_shape))
fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1[:, 0:1],
jnp.ones(meta_shape))
fp8_scale_inv_array1 = 1 / fp8_scale_array1
fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2, jnp.ones(meta_shape))
fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2[:, 0:1],
jnp.ones(meta_shape))
fp8_scale_inv_array2 = 1 / fp8_scale_array2
state = flax.core.frozen_dict.FrozenDict({
......@@ -156,6 +158,9 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_fp8_autocast(self):
FP8Helper.finalize()    # Ensure the test is not affected by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
self.assertFalse(FP8Helper.enable_fp8())
......@@ -167,7 +172,7 @@ class TestFP8Functions(unittest.TestCase):
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
self._check_defult_state()
ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
......@@ -176,16 +181,12 @@ class TestFP8Functions(unittest.TestCase):
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
self._check_defult_state()
ds = DelayedScaling(amax_history_len=2)
with self.assertRaises(AssertionError):
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(amax_history_len=2)):
pass
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_fp8_autocast_with_sharding_resource(self):
FP8Helper.finalize()    # Ensure the test is not affected by previous tests.
self._check_defult_state()
ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
......@@ -213,7 +214,7 @@ class TestFP8Functions(unittest.TestCase):
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
self.assertEqual(infer_major_sharding_type(), mst)
self._check_defult_state()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import tempfile
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import DenseGeneral
from transformer_engine.jax.fp8 import FP8Helper
from utils import is_fp8_supported
class MLPNN(nn.Module):
    use_fp8_dense: bool = True

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))    # flatten
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        features = [256, 256, 128]
        for feature in features:
            x = DenseGeneral(features=feature, transpose_batch_sequence=False,
                             dtype=jnp.bfloat16, use_bias=True)(x) \
                if self.use_fp8_dense else nn.Dense(features=feature)(x)
            x = nn.relu(x)
        x = nn.Dense(features=10, use_bias=True)(x)
        return x


def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()


def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist', data_dir="/tmp/tensorflow-datasets/downloads")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate, momentum, use_fp8_dense):
    """Creates initial `TrainState`."""
    cnn = MLPNN(use_fp8_dense=use_fp8_dense)
    variables = cnn.init(rng, jnp.ones([32, 28, 28, 1]))
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=variables['params'],
                                         tx=tx), variables


@partial(jax.jit, static_argnums=(3,))
def train_step(state, others, batch, use_fp8_dense):
    """Train for a single step."""

    def loss_fn(collections):
        logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(collections, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(others)
    state = state.apply_gradients(grads=grads['params'])
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics, grads


def train_epoch(state, variables, train_ds, batch_size, rng, use_fp8_dense):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for idx, perm in enumerate(perms):
        idx = idx + 1
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics, grads = train_step(state, variables, batch, use_fp8_dense)

        updated_coll = {'params': state.params}
        if use_fp8_dense:
            updated_coll[FP8Helper.FP8_COLLECTION_NAME] \
                = grads[FP8Helper.FP8_COLLECTION_NAME]
        variables = FP8Helper.update_collections(updated_coll, variables)

        batch_metrics.append(metrics)

    if use_fp8_dense:
        variables = FP8Helper.update_fp8_metas(variables)

    return state, variables


@partial(jax.jit, static_argnums=(2,))
def eval_step(variables, batch, use_fp8_dense):
    logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(variables, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


def eval_model(variables, test_ds, batch_size, use_fp8_dense):
    test_ds_size = len(test_ds['image'])
    steps_per_epoch = test_ds_size // batch_size

    perms = np.arange(0, test_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    total_summary = {'correct': 0, 'loss': 0, 'total': 0}
    for _, perm in enumerate(perms):
        batch = {k: v[perm, ...] for k, v in test_ds.items()}
        metrics = eval_step(variables, batch, use_fp8_dense)
        metrics = jax.device_get(metrics)
        summary = jax.tree_map(lambda x: x.item(), metrics)
        total_summary['correct'] += summary['accuracy'] * batch_size
        total_summary['loss'] += summary['loss'] * batch_size
        total_summary['total'] += batch_size

    return total_summary['loss'] / total_summary['total'], \
           total_summary['correct'] / total_summary['total']


class TestMnist(unittest.TestCase):

    def setUp(self) -> None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        self.learning_rate = 0.1
        self.momentum = 0.9
        self.num_epochs = 5
        self.batch_size = 512
        self.train_ds, self.test_ds = get_datasets()
        self.margin = 0.0
        self.num_fp8_layers = 3
        self.fp8_meta_update_interval = 1
        self.temp_file = tempfile.NamedTemporaryFile()    # pylint: disable=consider-using-with
        self.fp8_ckpt_path = self.temp_file.name
        self.seed = 0
        acc_bfp16_ = self._mnist_baseline_runner()
        acc_rtol = 0.005
        self.target_accuracy = acc_bfp16_ * (1. - acc_rtol)

    def tearDown(self):
        self.temp_file.close()

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_e4m3(self):
        self._mnist_test_runner(FP8Format.E4M3)

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_hybrid(self):
        self._mnist_test_runner(FP8Format.HYBRID)

    # Skip for now due to lack of bf16 in TE.Format
    # def test_mnist_bfloa16(self):
    #     self._mnist_test_runner(FP8Format.BFLOAT16)

    def _mnist_baseline_runner(self):
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, False)
        del init_rng
        _, accuracy = self._train_model(state, variables, self.num_epochs, rng, False)
        return accuracy

    def _mnist_test_runner(self, fp8_format):
        FP8Helper.initialize(margin=self.margin, fp8_format=fp8_format)

        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, True)
        del init_rng
        _, test_accuracy = self._train_model(state, variables, self.num_epochs, rng, True)
        self.assertGreater(
            test_accuracy, self.target_accuracy,
            f"Convergence test failed on MNIST with FP8Format.{fp8_format.name}. "
            f"Test Accuracy {test_accuracy:.4f} is lower than target {self.target_accuracy:.4f}")

        FP8Helper.finalize()

    def _train_model(self, state, variables, epochs, rng, use_fp8_dense):
        max_test_acc = 0.0
        for _ in range(0, epochs):
            rng, input_rng = jax.random.split(rng)
            state, variables = train_epoch(state, variables, self.train_ds, self.batch_size,
                                           input_rng, use_fp8_dense)
            _, test_accuracy = eval_model(variables, self.test_ds, self.batch_size, use_fp8_dense)
            max_test_acc = test_accuracy if test_accuracy > max_test_acc else max_test_acc
        return state, max_test_acc


if __name__ == '__main__':
    unittest.main()
......@@ -7,6 +7,7 @@ import numpy as np
import pytest
from jax.experimental import maps
from transformer_engine.jax import extend_logical_axis_rules
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
......@@ -14,6 +15,7 @@ from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
from utils import is_devices_enough
def _get_sharding_resource(mesh_names, sharding_type):
......@@ -47,14 +49,25 @@ SRS = [
]
def is_devices_enough():
return len(jax.devices()) >= DEVICE_COUNT
class TestShardingSideAPI:
@pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
@pytest.mark.parametrize('sr', SRS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr):
try:
target_te_rules = extend_logical_axis_rules(tuple())
extended_rules = extend_logical_axis_rules(base_rules)
assert extended_rules == (*base_rules, *target_te_rules)
assert not need_assert
except AssertionError as ae:
assert need_assert, f"{ae.args}"
class TestGeneralFunc:
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_infer_major_sharding_type(
self,
mesh_shape, # pylint: disable=unused-argument
......@@ -94,7 +107,7 @@ class TestShardingMetaGenerator:
MODEL_AXIS_NAME = 'model'
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):
def stack_axes_meta(mapping):
......@@ -145,7 +158,7 @@ class TestShardingMetaGenerator:
@pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
((128, 64, 512), (512, 256))])
@pytest.mark.parametrize('batch_dim_of_a', [0, 1])
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
model_dim_of_a = len(a_shape) - 1
model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
......@@ -240,7 +253,7 @@ class TestShardingMetaGenerator:
@pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
@pytest.mark.parametrize('other_shape', [(256,), (512,)])
@pytest.mark.parametrize('batch_dim', [0, 1])
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
batch_dim):
......
......@@ -2,5 +2,8 @@
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .fp8 import fp8_autocast
from .fp8 import fp8_autocast, update_collections, update_fp8_metas
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import RelativePositionBiases, TransformerLayer, TransformerLayerType
from .sharding import ShardingResource
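As a quick orientation for the newly exported API above, here is a minimal sketch (hyperparameters and shapes are made up) of driving one of the high-level modules under fp8_autocast, mirroring the encoder examples earlier in this commit:

import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.DenseGeneral(features=128, transpose_batch_sequence=False,
                        dtype=jnp.bfloat16, use_bias=True)
x = jnp.ones((32, 256), jnp.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    # init creates the params plus the FP8 meta collections used by the layer
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)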
......@@ -276,7 +276,7 @@ class CastTransposePrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get([ir_in_shape[0], ir_in_shape[1]], ir_out_dtype),
ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......@@ -427,7 +427,7 @@ class GatedGeluFp8Primitive(BasePrimitive):
batch_size = ir_in_shape[0] # In Transformer, batch_size = batch x seqlen
out_types = [
ir.RankedTensorType.get([batch_size, hidden_size // 2], ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......@@ -599,7 +599,7 @@ class DgatedGeluCastTransposePrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get([gi_batch_size, gi_hidden_size], ir_out_dtype),
ir.RankedTensorType.get([gi_hidden_size, gi_batch_size], ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, gelu_inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, gi_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......@@ -915,7 +915,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get((batch_size,), ir_mu_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
......@@ -1196,7 +1196,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, w_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......@@ -1369,7 +1369,7 @@ class QuantizePrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(ir_out_shape, ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......
......@@ -21,7 +21,6 @@ jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
......@@ -45,7 +44,6 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
amax,
scale,
scale_inv,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
......@@ -77,7 +75,6 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
[sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
partial_fp8_dot = partial(_fp8_dot,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
......@@ -93,18 +90,17 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, amax_history_idx: int, fwd_dtype: TEDType,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
sharding_type: ShardingType, dp_axis_name: str, tp_axis_name: str):
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
......@@ -121,7 +117,6 @@ def _fp8_dot_fwd(
amax,
scale,
scale_inv,
amax_history_idx, # pylint: disable=unused-argument
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
......@@ -139,6 +134,8 @@ def _fp8_dot_fwd(
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
amax = FP8Helper.update_amax_history(amax)
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
......@@ -170,7 +167,6 @@ def _fp8_dot_fwd(
def _fp8_dot_bwd(
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
......@@ -202,9 +198,9 @@ def _fp8_dot_bwd(
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
......
......@@ -4,12 +4,14 @@
"""
Helper module for fp8 meta management
"""
import os
from contextlib import contextmanager
from typing import Optional, Union, Dict, List, Tuple
from flax.core.frozen_dict import FrozenDict
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
......@@ -110,6 +112,12 @@ class FP8GemmPackage:
return self._scale_inv
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
MAX = "max"
MOST_RECENT = "most_recent"
class FP8Helper:
"""
FP8 helper to manage the FP8 meta
......@@ -120,7 +128,8 @@ class FP8Helper:
FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_SIZE: int = 1
AMAX_HISTORY_LEN: int = 1
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT
NUM_META_PER_GEMM: int = 3
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
......@@ -130,12 +139,9 @@ class FP8Helper:
FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
FP8_MAX_NAME: str = "fp8_max"
FP8_2X_ACC_FPROP_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_FPROP"
FP8_2X_ACC_DGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_DGRAD"
FP8_2X_ACC_WGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_WGRAD"
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
FP8_2X_ACC_DGRAD: bool = True
FP8_2X_ACC_WGRAD: bool = True
@staticmethod
def enable_fp8():
......@@ -148,7 +154,8 @@ class FP8Helper:
def initialize(margin: float = 0.0,
fp8_format: Format = Format.HYBRID,
update_fp8meta_interval: int = 1,
amax_history_size: int = 1) -> None:
amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT) -> None:
"""
Initialize the FP8 meta
"""
......@@ -158,13 +165,11 @@ class FP8Helper:
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
FP8Helper.AMAX_HISTORY_SIZE = amax_history_size
FP8Helper.FP8_2X_ACC_FPROP = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_FPROP_ENV_VAR_NAME, False)))
FP8Helper.FP8_2X_ACC_DGRAD = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_DGRAD_ENV_VAR_NAME, False)))
FP8Helper.FP8_2X_ACC_WGRAD = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_WGRAD_ENV_VAR_NAME, False)))
FP8Helper.AMAX_HISTORY_LEN = amax_history_len
FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
FP8Helper.FP8_2X_ACC_FPROP = False
FP8Helper.FP8_2X_ACC_DGRAD = True
FP8Helper.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
......@@ -177,10 +182,19 @@ class FP8Helper:
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_SIZE = 1
FP8Helper.AMAX_HISTORY_LEN = 1
@staticmethod
def update_collections(new: Collection, original: Collection) -> None:
def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers
@staticmethod
def update_collections(new: Collection, original: Collection) -> Collection:
"""
Update the collections
"""
......@@ -244,7 +258,10 @@ class FP8Helper:
fp8_scale_inv_idx = fp8_scale_idx + 1
fp8_max = fp8_meta_arrays[fp8_max_idx]
amax = fp8_meta_arrays[fp8_amax_idx]
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=1, keepdims=True)
else:
amax = fp8_meta_arrays[fp8_amax_idx][:, 0:1]
scale = fp8_meta_arrays[fp8_scale_idx]
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
......@@ -262,7 +279,7 @@ class FP8Helper:
def fp8_autocast(enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
sharding_resource: Optional[ShardingResource] = None) -> None:
"""
r"""
Context manager for FP8 usage.
.. code-block:: python
......@@ -284,15 +301,15 @@ def fp8_autocast(enabled: bool = False,
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
:attr:`amax_history_len=1` in recipe.DelayedScaling currently. Other parameters
:attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
in recipe.DelayedScaling would be ignored, even if set.
Parameters
----------
enabled: bool, default = False
whether or not to enable fp8
whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
recipe used for FP8 training.
recipe used for FP8 training.
sharding_resource: ShardingResource, default = None
specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then ShardingResource() would be created.
......@@ -300,19 +317,70 @@ def fp8_autocast(enabled: bool = False,
if fp8_recipe is None:
fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_history_len == 1, \
"It only support amax_history_len == 1 for now."
if sharding_resource is None:
sharding_resource = ShardingResource()
try:
with global_shard_guard(sharding_resource):
if enabled:
amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
if fp8_recipe.amax_compute_algo == 'max':
amax_compute_algo = AmaxComputeAlgo.MAX
FP8Helper.initialize(margin=fp8_recipe.margin,
fp8_format=fp8_recipe.fp8_format,
update_fp8meta_interval=fp8_recipe.interval,
amax_history_size=fp8_recipe.amax_history_len)
amax_history_len=fp8_recipe.amax_history_len,
amax_compute_algo=amax_compute_algo)
yield
finally:
FP8Helper.finalize()
# Function Wrappers
def update_collections(new: Collection, original: Collection) -> Collection:
r"""
A helper to update Flax's Collection. Collection is a union type of dict and
Flax's FrozenDict.
Collection = [dict, FrozenDict]
Parameters
----------
new: Collection
A collection that includes new data.
original: Collection
The base collection.
Returns
-------
outputs : Collection
The updated collection.
"""
return FP8Helper.update_collections(new, original)
def update_fp8_metas(state: Collection) -> Collection:
r"""
Calculate new fp8 scales and their inverses via the following formulas
`exp` = floor(log2(`fp8_max` / `amax`)) - `margin`
`sf` = round(power(2, abs(exp)))
`sf` = `sf` if `amax` > 0.0, else original_scale
`sf` = `sf` if isfinite(`amax`), else original_scale
`updated_scale` = `1/sf` if exp < 0, else `sf`
`updated_scale_inv` = `1/updated_scale`
Collection = [dict, FrozenDict]
Parameters
----------
state: Collection
A collection that includes FP8 metas.
Returns
-------
outputs : Collection
The collection with updated FP8 metas.
"""
return FP8Helper.update_fp8_metas(state)
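As a concrete feel for the scale-update formulas above, here is a worked example with hypothetical numbers, using the E4M3 format (`fp8_max` = 448) and `margin` = 0:

    `amax` = 3.5
    `exp` = floor(log2(448 / 3.5)) - 0 = floor(log2(128)) = 7
    `sf` = round(2^7) = 128
    `exp` >= 0, so `updated_scale` = 128 and `updated_scale_inv` = 1/128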
......@@ -132,7 +132,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
gamma: jnp.ndarray,
beta: jnp.ndarray,
layernorm_type: str,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
......@@ -167,7 +166,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
scale,
scale_inv,
layernorm_type,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims,
......@@ -213,7 +211,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
layernorm_type=layernorm_type,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
......@@ -233,7 +230,7 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _layernorm_fp8_dot(inputs: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -243,7 +240,6 @@ def _layernorm_fp8_dot(inputs: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
layernorm_type: str,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
......@@ -252,9 +248,9 @@ def _layernorm_fp8_dot(inputs: jnp.ndarray,
tp_axis_name: str,
epsilon: float = 1e-6) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, amax_history_idx, fwd_dtype,
bwd_dtype, contracting_dims, sharding_type, dp_axis_name,
tp_axis_name, epsilon)
scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
contracting_dims, sharding_type, dp_axis_name, tp_axis_name,
epsilon)
return output
......@@ -268,7 +264,6 @@ def _layernorm_fp8_dot_fwd(
scale,
scale_inv,
layernorm_type,
amax_history_idx, # pylint: disable=unused-argument
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
......@@ -286,6 +281,8 @@ def _layernorm_fp8_dot_fwd(
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
amax = FP8Helper.update_amax_history(amax)
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
......@@ -337,7 +334,6 @@ def _layernorm_fp8_dot_fwd(
def _layernorm_fp8_dot_bwd(
layernorm_type,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
......@@ -386,9 +382,9 @@ def _layernorm_fp8_dot_bwd(
grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
grad_beta = None
amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
......
......@@ -101,7 +101,6 @@ def fp8_ln_mlp(
ln_scale: jnp.ndarray,
ln_bias: jnp.ndarray,
layernorm_type: str,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
epsilon: float = 1e-6,
......@@ -130,8 +129,8 @@ def fp8_ln_mlp(
assert activations == ('gelu', 'linear')
if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, amax_history_idx, activations, epsilon, fwd_dtype,
bwd_dtype, contracting_dims, major_sharding_type, "", "")
scale_inv, layernorm_type, activations, epsilon, fwd_dtype, bwd_dtype,
contracting_dims, major_sharding_type, "", "")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
......@@ -177,7 +176,6 @@ def fp8_ln_mlp(
partial_fp8_mlp = partial(_fp8_mlp,
layernorm_type=layernorm_type,
amax_history_idx=amax_history_idx,
activations=activations,
epsilon=epsilon,
fwd_dtype=fwd_dtype,
......@@ -198,10 +196,10 @@ def fp8_ln_mlp(
return res
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, amax_history_idx: int,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
activations: Sequence[Union[str, Callable]], epsilon: float, fwd_dtype: TEDType,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
......@@ -215,7 +213,6 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
scale,
scale_inv,
layernorm_type,
amax_history_idx,
activations,
epsilon,
fwd_dtype,
......@@ -238,7 +235,6 @@ def _fp8_mlp_fwd(
scale,
scale_inv,
layernorm_type,
amax_history_idx, # pylint: disable=unused-argument
activations,
epsilon,
fwd_dtype,
......@@ -266,6 +262,8 @@ def _fp8_mlp_fwd(
kernel_1_ = jnp.reshape(kernel_1, (kernel_1_pre_size, -1))
kernel_2_ = jnp.reshape(kernel_2, (kernel_2_pre_size, -1))
amax = FP8Helper.update_amax_history(amax)
gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm1_input_idx]
......@@ -335,7 +333,6 @@ def _fp8_mlp_fwd(
def _fp8_mlp_bwd(
layernorm_type,
amax_history_idx,
activations, # pylint: disable=unused-argument
epsilon,
fwd_dtype,
......@@ -405,12 +402,12 @@ def _fp8_mlp_bwd(
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
grad_beta = None
amax = amax.at[gemm1_input_idx, amax_history_idx].set(ln_out_amax[0])
amax = amax.at[gemm1_kernel_idx, amax_history_idx].set(kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, amax_history_idx].set(dgelu_amax[0])
amax = amax.at[gemm2_input_idx, amax_history_idx].set(gated_gelu_amax[0])
amax = amax.at[gemm2_kernel_idx, amax_history_idx].set(kernel_2_amax[0])
amax = amax.at[gemm2_grad_idx, amax_history_idx].set(grad_amax[0])
amax = amax.at[gemm1_input_idx, 0].set(ln_out_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(dgelu_amax[0])
amax = amax.at[gemm2_input_idx, 0].set(gated_gelu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(kernel_2_amax[0])
amax = amax.at[gemm2_grad_idx, 0].set(grad_amax[0])
if major_sharding_type in (MajorShardingType.DP, MajorShardingType.DPTP):
wgrad_1 = jax.lax.psum(wgrad_1, dp_axis_name)
......
......@@ -694,9 +694,10 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
in_axes = [{dp_dim: dp_axis_name}]
input_new_shapes = [input_new_shape]
return ShardingMeta(tuple(in_axes), ({
dp_dim: dp_axis_name
}), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {dp_axis_name: dp_mesh_axis},
input_new_shapes, [input_shape])
def get_tp_col_sharding_meta(self,
input_shape: Tuple,
......@@ -764,9 +765,10 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
in_axes = [{tp_dim: tp_axis_name}]
input_new_shapes = [input_new_shape]
return ShardingMeta(tuple(in_axes), ({
tp_dim: tp_axis_name
}), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {tp_axis_name: tp_mesh_axis},
input_new_shapes, [input_shape])
@staticmethod
def _get_dptp_sharding_meta(input_shape: Tuple,
......@@ -794,7 +796,7 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}]
input_new_shapes = [input_new_shape]
out_axes = in_axes
out_axes = in_axes[0]
return ShardingMeta(tuple(in_axes), out_axes, {
dp_axis_name: dp_mesh_axis,
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX softmax modules"""
from enum import Enum
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
from .cpp_extensions import scaled_softmax_fwd
from .cpp_extensions import scaled_softmax_bwd
from .cpp_extensions import scaled_masked_softmax_fwd
from .cpp_extensions import scaled_masked_softmax_bwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_fwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
from .sharding import get_softmax_sharding_meta, ShardingType
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
class SoftmaxType(Enum):
    """SoftmaxType."""
    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"


def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: int, q_seqlen: int,
                                k_seqlen: int, dtype: jnp.dtype):
    """check softmax available"""
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                                   dtype)
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype)

    raise NotImplementedError


def softmax(inputs: jnp.ndarray,
            mask: Optional[jnp.ndarray] = None,
            scale_factor: Optional[float] = 1.0,
            softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
            sharding_type: ShardingType = ShardingType.SINGLE,
            dp_dim_index: int = 0,
            tp_dim_index: int = 1):
    """
    Softmax wrapper
    """
    assert dp_dim_index == 0, \
        "Only softmax support batch dim in the first place currently."
    assert tp_dim_index == 1, \
        "Only softmax support head dim in the second place currently."
    assert mask is None or mask.shape[tp_dim_index] == 1

    if sharding_type is ShardingType.SINGLE:
        outputs = _softmax(inputs, mask, scale_factor, softmax_type)
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"

        sharding_meta = get_softmax_sharding_meta(sharding_type,
                                                  inputs.shape,
                                                  dp_dim=dp_dim_index,
                                                  tp_dim=tp_dim_index,
                                                  dp_axis_name=dp_axis_name,
                                                  tp_axis_name=tp_axis_name)

        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input

        mask_ = mask
        mask_in_axis = {}
        if mask_ is not None:
            if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
                # If mask is head broadcastable (heads == 1),
                # then it equals to DP sharding.
                mask_sharding_meta = get_softmax_sharding_meta(ShardingType.DP,
                                                               mask_.shape,
                                                               dp_dim=dp_dim_index,
                                                               tp_dim=tp_dim_index,
                                                               dp_axis_name=dp_axis_name,
                                                               tp_axis_name=tp_axis_name)
                mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
                mask_in_axis = mask_sharding_meta.in_axes[0]

        partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type)

        in_axes = (sharding_meta.in_axes[0], mask_in_axis)
        outputs = xmap_runner(partial_softmax, in_axes, sharding_meta.out_axes,
                              sharding_meta.axis_resources, (inputs_, mask_))

        outputs = jnp.reshape(outputs, sharding_meta.output_shapes[0])

    return outputs


@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(inputs, mask, scale_factor, softmax_type):
    output, _ = _softmax_fwd(inputs, mask, scale_factor, softmax_type)
    return output


def _softmax_fwd(inputs, mask, scale_factor, softmax_type):
    if softmax_type is SoftmaxType.SCALED_MASKED:
        assert mask is not None
        outputs = scaled_masked_softmax_fwd(inputs, mask, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        outputs = scaled_upper_triang_masked_softmax_fwd(inputs, scale_factor)
    else:
        outputs = scaled_softmax_fwd(inputs, scale_factor)

    return outputs, (outputs, mask)


def _softmax_bwd(scale_factor, softmax_type, ctx, grad_outputs):
    softmax_outputs, mask = ctx

    if softmax_type is SoftmaxType.SCALED_MASKED:
        assert mask is not None
        dgrad = scaled_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        dgrad = scaled_upper_triang_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
    else:
        dgrad = scaled_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)

    return (dgrad, None)


_softmax.defvjp(_softmax_fwd, _softmax_bwd)
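A minimal usage sketch of the wrapper above (the import path and shapes are assumptions, and a GPU with the fused softmax kernel is required):

import jax.numpy as jnp
from transformer_engine.jax.softmax import softmax, SoftmaxType   # assumed module path

# [batch, heads, q_seqlen, k_seqlen] attention logits
logits = jnp.ones((8, 16, 128, 128), jnp.float16)
probs = softmax(logits, mask=None, scale_factor=1.0, softmax_type=SoftmaxType.SCALED)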