Unverified commit bc9d57a3, authored by Jeng Bai-Cheng, committed by GitHub

Add TE/JAX high-level modules, unittests and examples (#54)



* add transformer module, unittests and examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update tests/jax/test_sharding.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update transformer_engine/jax/transformer.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint: disable=line-too-long
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove pylint: disable=too-many-func-args
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Fix the wrong broadcasting dim of dropout masks when transpose_bs is enabled.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Enable 2xACC for WGRAD and DGRAD by default
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename LayerNormMlpBlock as LayerNormMLP
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor to avoid line-too-long
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename amax_history_size to amax_history_len
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* align dropout mask to TE/PyTorch as default
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* enlarge atol for decoder unittests

Two decoder unittests pass in an old JAX container (e.g., 23.02)
but fail in the latest container (devel).

1. The actual (-0.020264) and desired (-0.020386) values are very close.
2. The TE kernels are unchanged, so the difference should come from
   new codegen behavior in XLA.

Thus, it is ordinary accumulated floating-point error.
Enlarge atol to avoid unittest failures.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
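For reference, a minimal sketch of the kind of tolerance change described above, assuming a numpy-style assert_allclose helper; the exact values and tolerances used in the test suite may differ:

import numpy as np

actual, desired = -0.020264, -0.020386
# The default tolerances reject this accumulated floating-point error:
# np.testing.assert_allclose(actual, desired)    # raises AssertionError
# An enlarged atol absorbs the XLA codegen difference between containers:
np.testing.assert_allclose(actual, desired, atol=2e-4)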

* Adding Amax History Support

1. hide the amax update inside custom_vjp
2. replace amax indexing with roll (using a circular buffer)
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
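A self-contained sketch of the roll-based circular buffer described in item 2, mirroring FP8Helper.update_amax_history from this commit; the wrapper name and the toy history length are illustrative:

import jax.numpy as jnp

def roll_amax_history(amax_buffers):
    # Age every per-tensor history by one slot, then clear slot 0 so the
    # current step's amax can be written there.
    rolled = jnp.roll(amax_buffers, shift=-1, axis=1)
    return rolled.at[:, 0].set(0.0)

history = jnp.zeros((1, 3))    # one tensor, history length 3
for step_amax in (1.0, 2.0, 3.0, 4.0):
    history = roll_amax_history(history)
    history = history.at[:, 0].set(step_amax)
# history == [[4.0, 2.0, 3.0]]: slot 0 holds the newest amax and the
# oldest value (1.0) has been rotated out.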

* move kernel_init to __post_init__
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor encoder examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove envvar regarding 2xACC
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused import
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5925d444
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder with BF16 training on a single GPU."""
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state

import transformer_engine.jax as te

PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024


def network():
    """NLP Encoder"""
    encoder = te.TransformerLayer(hidden_size=HIDDEN,
                                  mlp_hidden_size=4 * HIDDEN,
                                  hidden_dropout=0.0,
                                  attention_dropout=0.0,
                                  layernorm_type='rmsnorm',
                                  mlp_activations=('gelu', 'linear'),
                                  layer_type=te.TransformerLayerType.ENCODER,
                                  transpose_batch_sequence=True,
                                  dtype=jnp.bfloat16)
    return encoder


def synthesis_data(data_rng):
    """Dataset generator"""
    return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)


def train_step(batch, state, others):
    """Training function."""

    def loss_fn(collections):
        logits = state.apply_fn(collections, batch)
        loss = jnp.mean(logits)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
    grads, params_grads = grads.pop(PARAMS_KEY)
    state = state.apply_gradients(grads=params_grads)
    return loss, state, others


def test_encoder():
    """Encoder example"""
    rng = jax.random.PRNGKey(0)
    rng, init_rng, data_rng = jax.random.split(rng, 3)
    inputs = synthesis_data(data_rng)

    encoder = network()
    variables = jax.jit(encoder.init)(init_rng, inputs)
    variables, params = variables.pop(PARAMS_KEY)

    optimizer = optax.sgd(0.001, 0.9)
    state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)

    jitted_train_step = jax.jit(train_step)
    for i in range(5):
        rng, data_rng = jax.random.split(rng)
        inputs = synthesis_data(data_rng)
        loss, state, variables = jitted_train_step(inputs, state, variables)
        print(f"Step {i} - Loss: {loss}")


if __name__ == "__main__":
    test_encoder()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder with FP8 training on a single GPU."""
import jax
import jax.numpy as jnp
import optax
from cuda import cudart
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state

import transformer_engine.jax as te
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.common.recipe import DelayedScaling

PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024


def gpu_has_fp8():
    """GPU arch has to support FP8"""
    cudaSuccess = cudart.cudaError_t.cudaSuccess
    ret, gpu_id = cudart.cudaGetDevice()
    assert ret == cudaSuccess
    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
    _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
    _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
    sm_arch = major * 10 + minor
    return sm_arch >= 89


def network():
    """NLP Encoder"""
    encoder = te.TransformerLayer(hidden_size=HIDDEN,
                                  mlp_hidden_size=4 * HIDDEN,
                                  hidden_dropout=0.0,
                                  attention_dropout=0.0,
                                  layernorm_type='rmsnorm',
                                  mlp_activations=('gelu', 'linear'),
                                  layer_type=te.TransformerLayerType.ENCODER,
                                  transpose_batch_sequence=True,
                                  dtype=jnp.bfloat16)
    return encoder


def synthesis_data(data_rng):
    """Dataset generator"""
    return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)


def train_step(batch, state, others):
    """Training function."""

    def loss_fn(collections):
        logits = state.apply_fn(collections, batch)
        loss = jnp.mean(logits)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
    grads, params_grads = grads.pop(PARAMS_KEY)
    state = state.apply_gradients(grads=params_grads)
    others = FP8Helper.update_fp8_metas(grads)
    return loss, state, others


def test_encoder():
    """Encoder example"""
    if not gpu_has_fp8():
        print("GPU doesn't support FP8")
        return

    rng = jax.random.PRNGKey(0)
    rng, init_rng, data_rng = jax.random.split(rng, 3)
    inputs = synthesis_data(data_rng)
    optimizer = optax.sgd(0.001, 0.9)

    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=FP8Format.HYBRID)):
        encoder = network()
        variables = jax.jit(encoder.init)(init_rng, inputs)
        variables, params = variables.pop(PARAMS_KEY)
        state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)

        jitted_train_step = jax.jit(train_step)
        assert "fp8" in str(jax.make_jaxpr(jitted_train_step)(inputs, state, variables))

        for i in range(5):
            rng, data_rng = jax.random.split(rng)
            inputs = synthesis_data(data_rng)
            loss, state, variables = jitted_train_step(inputs, state, variables)
            print(f"Step {i} - Loss: {loss}")


if __name__ == "__main__":
    test_encoder()
@@ -6,3 +6,4 @@ set -xe
 : ${TE_PATH:=/opt/transformerengine}
 pytest -Wignore -v $TE_PATH/tests/jax
+pytest -Wignore -v $TE_PATH/examples/jax
@@ -60,7 +60,7 @@ class TestFP8Dot:
         def func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
@@ -68,7 +68,7 @@ class TestFP8Dot:
             # y = input, matrix 2d (weight)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))

         value_n_grad_func = value_and_grad(func, (0, 1))
         value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
@@ -84,13 +84,13 @@ class TestFP8Dot:
         def func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))

         value_n_grad_func = value_and_grad(func, (0, 1))
         value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
@@ -104,13 +104,13 @@ class TestFP8Dot:
         b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                       fp8_metas_scale_inv)
-        primitive_out = fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None))
+        primitive_out = fp8_dot(fp8_gemm_pkg, *_format2dtypes(None))
         ref_out = jnp.dot(a, b)
         assert_allclose(primitive_out, ref_out)
@@ -128,7 +128,7 @@ class TestFP8Dot:
         b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
@@ -136,12 +136,12 @@ class TestFP8Dot:
         # calculate amax
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
-        primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
+        primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)
         # calculate scale by amax
         fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
-        primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
+        primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)
         ref_out = jnp.dot(a, b)
         ref_out = ref_out.astype(jnp.float32)
@@ -158,13 +158,13 @@ class TestFP8Dot:
         def primitive_func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.mean(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
+            return jnp.mean(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))

         def ref_func(x, y):
             return jnp.mean(jnp.dot(x, y))
@@ -193,7 +193,7 @@ class TestFP8Dot:
         b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
@@ -201,7 +201,7 @@ class TestFP8Dot:
         def primitive_func(x, y, metas):
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *metas)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))

         def ref_func(x, y):
             return jnp.sum(jnp.dot(x, y))
@@ -232,13 +232,13 @@ class TestFP8Dot:
         def primitive_func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None), ((2, 3), (0, 1))))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None), ((2, 3), (0, 1))))

         def ref_func(x, y):
             return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))
@@ -266,7 +266,7 @@ class TestFP8Dot:
         s = jax.random.uniform(subkeys[3], (k,), jnp.bfloat16, 5, 8)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
@@ -283,7 +283,6 @@ class TestFP8Dot:
                     ln_s,
                     None,
                     "rmsnorm",
-                    0,
                     *compute_type,
                     activations=activations))
@@ -305,7 +304,6 @@ class TestFP8Dot:
                           amax: jnp.ndarray,
                           scale: jnp.ndarray,
                           scale_inv: jnp.ndarray,
-                          amax_history_idx: int,
                           fwd_dtype,
                           bwd_dtype,
                           epsilon=1e-6,
@@ -323,7 +321,6 @@ class TestFP8Dot:
                                               scale[:FP8Helper.NUM_META_PER_GEMM],
                                               scale_inv[:FP8Helper.NUM_META_PER_GEMM])
             linear_1_out = fp8_dot(fp8_gemm_1_pkg,
-                                   amax_history_idx,
                                    fwd_dtype,
                                    bwd_dtype,
                                    contracting_dims,
@@ -341,7 +338,6 @@ class TestFP8Dot:
                                               scale[FP8Helper.NUM_META_PER_GEMM:],
                                               scale_inv[FP8Helper.NUM_META_PER_GEMM:])
             output = fp8_dot(fp8_gemm_2_pkg,
-                             amax_history_idx,
                              fwd_dtype,
                              bwd_dtype,
                              contracting_dims,
@@ -350,7 +346,7 @@ class TestFP8Dot:
         def ref_func(x, ln_s, y, z, metas):
             return jnp.mean(
-                fp8_ln_mlp_py(x, ln_s, y, z, *metas, 0, *compute_type, activations=activations))
+                fp8_ln_mlp_py(x, ln_s, y, z, *metas, *compute_type, activations=activations))

         value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3)))
         value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3)))
...
@@ -27,12 +27,12 @@ class TestFP8Helper(unittest.TestCase):
         margin = 5.0
         fp8_format = FP8Format.E4M3
         update_fp8meta_interval = 10
-        amax_history_size = 10
+        amax_history_len = 10
         FP8Helper.initialize(margin=margin,
                              fp8_format=fp8_format,
                              update_fp8meta_interval=update_fp8meta_interval,
-                             amax_history_size=amax_history_size)
+                             amax_history_len=amax_history_len)
         self.assertEqual(
             FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
@@ -46,15 +46,15 @@ class TestFP8Helper(unittest.TestCase):
             "FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
             f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
         self.assertEqual(
-            FP8Helper.AMAX_HISTORY_SIZE, amax_history_size,
-            f"FP8Helper.AMAX_HISTORY_SIZE initialization failed, should be {amax_history_size}"
-            f" but got {FP8Helper.AMAX_HISTORY_SIZE}.")
+            FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
+            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
+            f" but got {FP8Helper.AMAX_HISTORY_LEN}.")
         FP8Helper.finalize()

     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_update_fp8_metas(self):
-        FP8Helper.initialize(margin=3.0, amax_history_size=5)
+        FP8Helper.initialize(margin=3.0, amax_history_len=3)
         seed = 0
         key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
@@ -72,13 +72,15 @@ class TestFP8Helper(unittest.TestCase):
             sf = np.where(np.isfinite(amax), sf, scale)
             return np.where(exp < 0, 1 / sf, sf)

-        meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_SIZE)
+        meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
         fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
         fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
-        fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1, jnp.ones(meta_shape))
+        fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1[:, 0:1],
+                                         jnp.ones(meta_shape))
         fp8_scale_inv_array1 = 1 / fp8_scale_array1
         fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
-        fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2, jnp.ones(meta_shape))
+        fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2[:, 0:1],
+                                         jnp.ones(meta_shape))
         fp8_scale_inv_array2 = 1 / fp8_scale_array2

         state = flax.core.frozen_dict.FrozenDict({
@@ -156,6 +158,9 @@ class TestFP8Functions(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_fp8_autocast(self):
+        FP8Helper.finalize()    # Ensure the testing not affect by previous tests.
+        self._check_defult_state()
+
         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
             self.assertFalse(FP8Helper.enable_fp8())
@@ -167,7 +172,7 @@ class TestFP8Functions(unittest.TestCase):
             self.assertEqual(FP8Helper.MARGIN, ds.margin)
             self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
             self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
+            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)

         self._check_defult_state()

         ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
@@ -176,16 +181,12 @@ class TestFP8Functions(unittest.TestCase):
             self.assertEqual(FP8Helper.MARGIN, ds.margin)
             self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
             self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
+            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)

         self._check_defult_state()

-        ds = DelayedScaling(amax_history_len=2)
-        with self.assertRaises(AssertionError):
-            with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(amax_history_len=2)):
-                pass
-
     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_fp8_autocast_with_sharding_resource(self):
+        FP8Helper.finalize()    # Ensure the testing not affect by previous tests.
         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
@@ -213,7 +214,7 @@ class TestFP8Functions(unittest.TestCase):
             self.assertEqual(FP8Helper.MARGIN, ds.margin)
             self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
             self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
+            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
             self.assertEqual(infer_major_sharding_type(), mst)

         self._check_defult_state()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import tempfile
import unittest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state

from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import DenseGeneral
from transformer_engine.jax.fp8 import FP8Helper
from utils import is_fp8_supported


class MLPNN(nn.Module):
    use_fp8_dense: bool = True

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))    # flatten
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        features = [256, 256, 128]
        for feature in features:
            x = DenseGeneral(features=feature, transpose_batch_sequence=False,
                             dtype=jnp.bfloat16, use_bias=True)(x) \
                if self.use_fp8_dense else nn.Dense(features=feature)(x)
            x = nn.relu(x)
        x = nn.Dense(features=10, use_bias=True)(x)
        return x


def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()


def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist', data_dir="/tmp/tensorflow-datasets/downloads")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate, momentum, use_fp8_dense):
    """Creates initial `TrainState`."""
    cnn = MLPNN(use_fp8_dense=use_fp8_dense)
    variables = cnn.init(rng, jnp.ones([32, 28, 28, 1]))
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=variables['params'],
                                         tx=tx), variables


@partial(jax.jit, static_argnums=(3,))
def train_step(state, others, batch, use_fp8_dense):
    """Train for a single step."""

    def loss_fn(collections):
        logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(collections, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(others)
    state = state.apply_gradients(grads=grads['params'])
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics, grads


def train_epoch(state, variables, train_ds, batch_size, rng, use_fp8_dense):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for idx, perm in enumerate(perms):
        idx = idx + 1
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics, grads = train_step(state, variables, batch, use_fp8_dense)
        updated_coll = {'params': state.params}
        if use_fp8_dense:
            updated_coll[FP8Helper.FP8_COLLECTION_NAME] = grads[FP8Helper.FP8_COLLECTION_NAME]
        variables = FP8Helper.update_collections(updated_coll, variables)
        batch_metrics.append(metrics)
        if use_fp8_dense:
            variables = FP8Helper.update_fp8_metas(variables)
    return state, variables


@partial(jax.jit, static_argnums=(2,))
def eval_step(variables, batch, use_fp8_dense):
    logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(variables, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


def eval_model(variables, test_ds, batch_size, use_fp8_dense):
    test_ds_size = len(test_ds['image'])
    steps_per_epoch = test_ds_size // batch_size
    perms = np.arange(0, test_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    total_summary = {'correct': 0, 'loss': 0, 'total': 0}
    for _, perm in enumerate(perms):
        batch = {k: v[perm, ...] for k, v in test_ds.items()}
        metrics = eval_step(variables, batch, use_fp8_dense)
        metrics = jax.device_get(metrics)
        summary = jax.tree_map(lambda x: x.item(), metrics)
        total_summary['correct'] += summary['accuracy'] * batch_size
        total_summary['loss'] += summary['loss'] * batch_size
        total_summary['total'] += batch_size
    return total_summary['loss'] / total_summary['total'], \
           total_summary['correct'] / total_summary['total']


class TestMnist(unittest.TestCase):

    def setUp(self) -> None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        self.learning_rate = 0.1
        self.momentum = 0.9
        self.num_epochs = 5
        self.batch_size = 512
        self.train_ds, self.test_ds = get_datasets()
        self.margin = 0.0
        self.num_fp8_layers = 3
        self.fp8_meta_update_interval = 1
        self.temp_file = tempfile.NamedTemporaryFile()    # pylint: disable=consider-using-with
        self.fp8_ckpt_path = self.temp_file.name
        self.seed = 0
        acc_bfp16_ = self._mnist_baseline_runner()
        acc_rtol = 0.005
        self.target_accuracy = acc_bfp16_ * (1. - acc_rtol)

    def tearDown(self):
        self.temp_file.close()

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_e4m3(self):
        self._mnist_test_runner(FP8Format.E4M3)

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_hybrid(self):
        self._mnist_test_runner(FP8Format.HYBRID)

    # Skip for now due to lack of bf16 in TE.Format
    # def test_mnist_bfloa16(self):
    #     self._mnist_test_runner(FP8Format.BFLOAT16)

    def _mnist_baseline_runner(self):
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, False)
        del init_rng
        _, accuracy = self._train_model(state, variables, self.num_epochs, rng, False)
        return accuracy

    def _mnist_test_runner(self, fp8_format):
        FP8Helper.initialize(margin=self.margin, fp8_format=fp8_format)
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, True)
        del init_rng
        _, test_accuracy = self._train_model(state, variables, self.num_epochs, rng, True)
        self.assertGreater(
            test_accuracy, self.target_accuracy,
            f"Convergence test failed on MNIST with FP8Format.{fp8_format.name}. "
            f"Test Accuracy {test_accuracy:.4f} is lower than target {self.target_accuracy:.4f}")
        FP8Helper.finalize()

    def _train_model(self, state, variables, epochs, rng, use_fp8_dense):
        max_test_acc = 0.0
        for _ in range(0, epochs):
            rng, input_rng = jax.random.split(rng)
            state, variables = train_epoch(state, variables, self.train_ds, self.batch_size,
                                           input_rng, use_fp8_dense)
            _, test_accuracy = eval_model(variables, self.test_ds, self.batch_size, use_fp8_dense)
            max_test_acc = test_accuracy if test_accuracy > max_test_acc else max_test_acc
        return state, max_test_acc


if __name__ == '__main__':
    unittest.main()
@@ -7,6 +7,7 @@ import numpy as np
 import pytest
 from jax.experimental import maps
+from transformer_engine.jax import extend_logical_axis_rules
 from transformer_engine.jax.sharding import get_dot_sharding_meta
 from transformer_engine.jax.sharding import get_elementwise_sharding_meta
 from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
@@ -14,6 +15,7 @@ from transformer_engine.jax.sharding import global_shard_guard
 from transformer_engine.jax.sharding import infer_major_sharding_type
 from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
 from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
+from utils import is_devices_enough

 def _get_sharding_resource(mesh_names, sharding_type):
@@ -47,14 +49,25 @@ SRS = [
 ]

-def is_devices_enough():
-    return len(jax.devices()) >= DEVICE_COUNT
+class TestShardingSideAPI:
+
+    @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
+    @pytest.mark.parametrize('sr', SRS)
+    def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
+        with global_shard_guard(sr):
+            try:
+                target_te_rules = extend_logical_axis_rules(tuple())
+                extended_rules = extend_logical_axis_rules(base_rules)
+                assert extended_rules == (*base_rules, *target_te_rules)
+                assert not need_assert
+            except AssertionError as ae:
+                assert need_assert, f"{ae.args}"

 class TestGeneralFunc:

     @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
-    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
+    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
     def test_infer_major_sharding_type(
             self,
             mesh_shape,    # pylint: disable=unused-argument
@@ -94,7 +107,7 @@ class TestShardingMetaGenerator:
     MODEL_AXIS_NAME = 'model'

     @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
-    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
+    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
     def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):

         def stack_axes_meta(mapping):
@@ -145,7 +158,7 @@ class TestShardingMetaGenerator:
     @pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
                                                   ((128, 64, 512), (512, 256))])
     @pytest.mark.parametrize('batch_dim_of_a', [0, 1])
-    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
+    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
     def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
         model_dim_of_a = len(a_shape) - 1
         model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
@@ -240,7 +253,7 @@ class TestShardingMetaGenerator:
     @pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
     @pytest.mark.parametrize('other_shape', [(256,), (512,)])
     @pytest.mark.parametrize('batch_dim', [0, 1])
-    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
+    @pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
     def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
                          batch_dim):
...
@@ -2,5 +2,8 @@
 #
 # See LICENSE for license information.
 """Transformer Engine bindings for JAX"""
-from .fp8 import fp8_autocast
+from .fp8 import fp8_autocast, update_collections, update_fp8_metas
+from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
+from .transformer import extend_logical_axis_rules
+from .transformer import RelativePositionBiases, TransformerLayer, TransformerLayerType
 from .sharding import ShardingResource
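For orientation, a short sketch of the newly re-exported public API in use, mirroring the encoder examples added by this commit; the shapes and hyperparameters are illustrative:

import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.TransformerLayer(hidden_size=1024,
                            mlp_hidden_size=4096,
                            layer_type=te.TransformerLayerType.ENCODER,
                            transpose_batch_sequence=True,
                            dtype=jnp.bfloat16)
x = jnp.zeros((512, 32, 1024), jnp.bfloat16)    # (seqlen, batch, hidden)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    variables = jax.jit(layer.init)(jax.random.PRNGKey(0), x)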
@@ -276,7 +276,7 @@ class CastTransposePrimitive(BasePrimitive):
         out_types = [
             ir.RankedTensorType.get([ir_in_shape[0], ir_in_shape[1]], ir_out_dtype),
             ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype),
-            ir.RankedTensorType.get((1,), ir_amax_dtype),
+            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
         ]
         operands = [inputs, amax, scale, scale_inv]
         operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -427,7 +427,7 @@ class GatedGeluFp8Primitive(BasePrimitive):
         batch_size = ir_in_shape[0]    # In Transformer, batch_size = batch x seqlen
         out_types = [
             ir.RankedTensorType.get([batch_size, hidden_size // 2], ir_out_dtype),
-            ir.RankedTensorType.get((1,), ir_amax_dtype),
+            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
         ]
         operands = [inputs, amax, scale, scale_inv]
         operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -599,7 +599,7 @@ class DgatedGeluCastTransposePrimitive(BasePrimitive):
         out_types = [
             ir.RankedTensorType.get([gi_batch_size, gi_hidden_size], ir_out_dtype),
             ir.RankedTensorType.get([gi_hidden_size, gi_batch_size], ir_out_dtype),
-            ir.RankedTensorType.get((1,), ir_amax_dtype),
+            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
         ]
         operands = [inputs, gelu_inputs, amax, scale, scale_inv]
         operand_shapes = [ir_in_shape, gi_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -915,7 +915,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
             ir.RankedTensorType.get(x_shape, ir_out_dtype),
             ir.RankedTensorType.get((batch_size,), ir_mu_dtype),
             ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
-            ir.RankedTensorType.get((1,), ir_amax_dtype),
+            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
         ]
         operands = [x, gamma, beta, amax, scale, scale_inv]
         operand_shapes = [
@@ -1196,7 +1196,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
         out_types = [
             ir.RankedTensorType.get(x_shape, ir_out_dtype),
             ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
-            ir.RankedTensorType.get((1,), ir_amax_dtype),
+            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
         ]
         operands = [x, gamma, amax, scale, scale_inv]
         operand_shapes = [x_shape, w_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -1369,7 +1369,7 @@ class QuantizePrimitive(BasePrimitive):
         out_types = [
             ir.RankedTensorType.get(ir_out_shape, ir_out_dtype),
-            ir.RankedTensorType.get((1,), ir_amax_dtype),
+            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
         ]
         operands = [inputs, amax, scale, scale_inv]
         operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
...
@@ -21,7 +21,6 @@ jax.config.update('experimental_xmap_spmd_lowering_manual', True)
 def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
-            amax_history_idx: int,
             fwd_dtype: TEDType,
             bwd_dtype: TEDType,
             contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
@@ -45,7 +44,6 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                         amax,
                         scale,
                         scale_inv,
-                        amax_history_idx=amax_history_idx,
                         fwd_dtype=fwd_dtype,
                         bwd_dtype=bwd_dtype,
                         contracting_dims=contracting_dims,
@@ -77,7 +75,6 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                                   [sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
         partial_fp8_dot = partial(_fp8_dot,
-                                  amax_history_idx=amax_history_idx,
                                   fwd_dtype=fwd_dtype,
                                   bwd_dtype=bwd_dtype,
                                   contracting_dims=contracting_dims,
@@ -93,18 +90,17 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
     return res

-@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
+@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11))
 def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
-             scale: jnp.ndarray, scale_inv: jnp.ndarray, amax_history_idx: int, fwd_dtype: TEDType,
-             bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
-             sharding_type: ShardingType, dp_axis_name: str, tp_axis_name: str):
+             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
+             contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
+             dp_axis_name: str, tp_axis_name: str):
     res, _ = _fp8_dot_fwd(inputs,
                           kernel,
                           fp8_maxs,
                           amax,
                           scale,
                           scale_inv,
-                          amax_history_idx,
                           fwd_dtype,
                           bwd_dtype,
                           contracting_dims=contracting_dims,
@@ -121,7 +117,6 @@ def _fp8_dot_fwd(
         amax,
         scale,
         scale_inv,
-        amax_history_idx,    # pylint: disable=unused-argument
         fwd_dtype,
         bwd_dtype,    # pylint: disable=unused-argument
         contracting_dims,
@@ -139,6 +134,8 @@ def _fp8_dot_fwd(
     inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
     kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))

+    amax = FP8Helper.update_amax_history(amax)
+
     gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

     input_amax = amax[gemm_input_idx]
@@ -170,7 +167,6 @@ def _fp8_dot_fwd(
 def _fp8_dot_bwd(
-        amax_history_idx,
         fwd_dtype,
         bwd_dtype,
         contracting_dims,    # pylint: disable=unused-argument
@@ -202,9 +198,9 @@ def _fp8_dot_bwd(
     dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
                  bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)

-    amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
-    amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
-    amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
+    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
+    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
+    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])

     if is_dp_enabled(sharding_type.value[0]):
         wgrad = jax.lax.psum(wgrad, dp_axis_name)
...
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
""" """
Helper module for fp8 meta management Helper module for fp8 meta management
""" """
import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Union, Dict, List, Tuple from enum import Enum
from flax.core.frozen_dict import FrozenDict from typing import Dict, List, Optional, Tuple, Union
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType from transformer_engine_jax import DType
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard from transformer_engine.jax.sharding import global_shard_guard
...@@ -110,6 +112,12 @@ class FP8GemmPackage: ...@@ -110,6 +112,12 @@ class FP8GemmPackage:
return self._scale_inv return self._scale_inv
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
MAX = "max"
MOST_RECENT = "most_recent"
class FP8Helper: class FP8Helper:
""" """
FP8 helper to manage the FP8 meta FP8 helper to manage the FP8 meta
...@@ -120,7 +128,8 @@ class FP8Helper: ...@@ -120,7 +128,8 @@ class FP8Helper:
FWD_DTYPE: DType = DType.kFloat8E4M3 FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2 BWD_DTYPE: DType = DType.kFloat8E5M2
UPDATE_FP8META_INTERVAL: int = 1 UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_SIZE: int = 1 AMAX_HISTORY_LEN: int = 1
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT
NUM_META_PER_GEMM: int = 3 NUM_META_PER_GEMM: int = 3
INPUT_META_IDX_PER_GEMM: int = 0 INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1 KERNEL_META_IDX_PER_GEMM: int = 1
...@@ -130,12 +139,9 @@ class FP8Helper: ...@@ -130,12 +139,9 @@ class FP8Helper:
FP8_SCALE_NAME: str = "fp8_meta_scale" FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv" FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
FP8_MAX_NAME: str = "fp8_max" FP8_MAX_NAME: str = "fp8_max"
FP8_2X_ACC_FPROP_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_FPROP"
FP8_2X_ACC_DGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_DGRAD"
FP8_2X_ACC_WGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_WGRAD"
FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = True
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = True
@staticmethod @staticmethod
def enable_fp8(): def enable_fp8():
...@@ -148,7 +154,8 @@ class FP8Helper: ...@@ -148,7 +154,8 @@ class FP8Helper:
def initialize(margin: float = 0.0, def initialize(margin: float = 0.0,
fp8_format: Format = Format.HYBRID, fp8_format: Format = Format.HYBRID,
update_fp8meta_interval: int = 1, update_fp8meta_interval: int = 1,
amax_history_size: int = 1) -> None: amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT) -> None:
""" """
Initialize the FP8 meta Initialize the FP8 meta
""" """
...@@ -158,13 +165,11 @@ class FP8Helper: ...@@ -158,13 +165,11 @@ class FP8Helper:
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \ FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT) _format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
FP8Helper.AMAX_HISTORY_SIZE = amax_history_size FP8Helper.AMAX_HISTORY_LEN = amax_history_len
FP8Helper.FP8_2X_ACC_FPROP = bool( FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
int(os.environ.get(FP8Helper.FP8_2X_ACC_FPROP_ENV_VAR_NAME, False))) FP8Helper.FP8_2X_ACC_FPROP = False
FP8Helper.FP8_2X_ACC_DGRAD = bool( FP8Helper.FP8_2X_ACC_DGRAD = True
int(os.environ.get(FP8Helper.FP8_2X_ACC_DGRAD_ENV_VAR_NAME, False))) FP8Helper.FP8_2X_ACC_WGRAD = True
FP8Helper.FP8_2X_ACC_WGRAD = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_WGRAD_ENV_VAR_NAME, False)))
@staticmethod @staticmethod
def finalize() -> None: def finalize() -> None:
...@@ -177,10 +182,19 @@ class FP8Helper: ...@@ -177,10 +182,19 @@ class FP8Helper:
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3 FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2 FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.UPDATE_FP8META_INTERVAL = 1 FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_SIZE = 1 FP8Helper.AMAX_HISTORY_LEN = 1
@staticmethod @staticmethod
def update_collections(new: Collection, original: Collection) -> None: def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers
@staticmethod
def update_collections(new: Collection, original: Collection) -> Collection:
""" """
Update the collections Update the collections
""" """
...@@ -244,7 +258,10 @@ class FP8Helper: ...@@ -244,7 +258,10 @@ class FP8Helper:
fp8_scale_inv_idx = fp8_scale_idx + 1 fp8_scale_inv_idx = fp8_scale_idx + 1
fp8_max = fp8_meta_arrays[fp8_max_idx] fp8_max = fp8_meta_arrays[fp8_max_idx]
amax = fp8_meta_arrays[fp8_amax_idx] if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=1, keepdims=True)
else:
amax = fp8_meta_arrays[fp8_amax_idx][:, 0:1]
scale = fp8_meta_arrays[fp8_scale_idx] scale = fp8_meta_arrays[fp8_scale_idx]
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
...@@ -262,7 +279,7 @@ class FP8Helper: ...@@ -262,7 +279,7 @@ class FP8Helper:
def fp8_autocast(enabled: bool = False, def fp8_autocast(enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None, fp8_recipe: Optional[DelayedScaling] = None,
sharding_resource: Optional[ShardingResource] = None) -> None: sharding_resource: Optional[ShardingResource] = None) -> None:
""" r"""
Context manager for FP8 usage. Context manager for FP8 usage.
.. code-block:: python .. code-block:: python
@@ -284,15 +301,15 @@ def fp8_autocast(enabled: bool = False,
     .. note::
         We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
-        :attr:`amax_history_len=1` in recipe.DelayedScaling currently. Other parameters
+        :attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
         in recipe.DelayedScaling would be ignored, even if set.

     Parameters
     ----------
     enabled: bool, default = False
         whether or not to enable fp8
     fp8_recipe: recipe.DelayedScaling, default = None
         recipe used for FP8 training.
     sharding_resource: ShardingResource, default = None
         specify the mesh axes for data and tensor parallelism to shard along.
         If set to None, then ShardingResource() would be created.
@@ -300,19 +317,70 @@ def fp8_autocast(enabled: bool = False,
     if fp8_recipe is None:
         fp8_recipe = DelayedScaling()

-    assert fp8_recipe.amax_history_len == 1, \
-        "It only support amax_history_len == 1 for now."
-
     if sharding_resource is None:
         sharding_resource = ShardingResource()

     try:
         with global_shard_guard(sharding_resource):
             if enabled:
+                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
+                if fp8_recipe.amax_compute_algo == 'max':
+                    amax_compute_algo = AmaxComputeAlgo.MAX
+
                 FP8Helper.initialize(margin=fp8_recipe.margin,
                                      fp8_format=fp8_recipe.fp8_format,
                                      update_fp8meta_interval=fp8_recipe.interval,
-                                     amax_history_size=fp8_recipe.amax_history_len)
+                                     amax_history_len=fp8_recipe.amax_history_len,
+                                     amax_compute_algo=amax_compute_algo)
             yield
     finally:
         FP8Helper.finalize()
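A usage sketch of the context manager with the newly supported recipe fields (the train_step call and its arguments are hypothetical; import paths assumed):

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast

recipe = DelayedScaling(margin=0, fp8_format=Format.HYBRID, interval=1,
                        amax_history_len=16, amax_compute_algo='max')

with fp8_autocast(enabled=True, fp8_recipe=recipe):
    state = train_step(state, batch)   # hypothetical training step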
+
+
+# Function Wrappers
+def update_collections(new: Collection, original: Collection) -> Collection:
+    r"""
+    A helper to update Flax's Collection. Collection is a union type of dict and
+    Flax's FrozenDict.
+
+    Collection = [dict, FrozenDict]
+
+    Parameters
+    ----------
+    new: Collection
+        A collection that includes new data.
+    original: Collection
+        The base collection.
+
+    Returns
+    -------
+    outputs : Collection
+        The updated collection.
+    """
+    return FP8Helper.update_collections(new, original)
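For example, assuming entries in `new` override matching keys in `original` (toy keys, not the library's collection names):

from flax.core.frozen_dict import FrozenDict

original = FrozenDict({'params': {'w': 1.0}, 'stats': {'amax': 0.0}})
new = {'stats': {'amax': 3.5}}

updated = update_collections(new, original)
# keys present in `new` replace those in `original`; the rest are kept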
+
+
+def update_fp8_metas(state: Collection) -> Collection:
+    r"""
+    Calculate new fp8 scales and their inverses via the following formulas
+
+        `exp` = floor(log2(`fp8_max` / `amax`)) - `margin`
+        `sf` = round(power(2, abs(`exp`)))
+        `sf` = `sf` if `amax` > 0.0, else original_scale
+        `sf` = `sf` if isfinite(`amax`), else original_scale
+        `updated_scale` = `1/sf` if `exp` < 0, else `sf`
+        `updated_scale_inv` = `1/updated_scale`
+
+    Collection = [dict, FrozenDict]
+
+    Parameters
+    ----------
+    state: Collection
+        A collection that includes FP8 metas.
+
+    Returns
+    -------
+    outputs : Collection
+        The collection with updated FP8 metas.
+    """
+    return FP8Helper.update_fp8_metas(state)
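The documented formula, written out as a self-contained sketch over a single scaling factor (toy numbers; this mirrors the docstring, not the library's internal implementation):

import jax.numpy as jnp

def compute_scale(fp8_max, amax, scale, margin):
    exp = jnp.floor(jnp.log2(fp8_max / amax)) - margin
    sf = jnp.round(jnp.power(2.0, jnp.abs(exp)))
    sf = jnp.where(amax > 0.0, sf, scale)          # keep old scale if amax == 0
    sf = jnp.where(jnp.isfinite(amax), sf, scale)  # keep old scale if amax is inf/nan
    return jnp.where(exp < 0, 1.0 / sf, sf)

# E4M3 max is 448: amax = 3.5 gives exp = floor(log2(128)) = 7, so scale = 2**7
print(compute_scale(448.0, 3.5, 1.0, 0))   # 128.0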
@@ -132,7 +132,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                        gamma: jnp.ndarray,
                        beta: jnp.ndarray,
                        layernorm_type: str,
-                       amax_history_idx: int,
                        fwd_dtype: TEDType,
                        bwd_dtype: TEDType,
                        contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
@@ -167,7 +166,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                        scale,
                        scale_inv,
                        layernorm_type,
-                       amax_history_idx,
                        fwd_dtype,
                        bwd_dtype,
                        contracting_dims,
@@ -213,7 +211,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
         partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
                                      layernorm_type=layernorm_type,
-                                     amax_history_idx=amax_history_idx,
                                      fwd_dtype=fwd_dtype,
                                      bwd_dtype=bwd_dtype,
                                      contracting_dims=contracting_dims,
@@ -233,7 +230,7 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
+@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
 def _layernorm_fp8_dot(inputs: jnp.ndarray,
                        kernel: jnp.ndarray,
                        gamma: jnp.ndarray,
@@ -243,7 +240,6 @@ def _layernorm_fp8_dot(inputs: jnp.ndarray,
                        scale: jnp.ndarray,
                        scale_inv: jnp.ndarray,
                        layernorm_type: str,
-                       amax_history_idx: int,
                        fwd_dtype: TEDType,
                        bwd_dtype: TEDType,
                        contracting_dims: Tuple[Sequence[int], Sequence[int]],
@@ -252,9 +248,9 @@ def _layernorm_fp8_dot(inputs: jnp.ndarray,
                        tp_axis_name: str,
                        epsilon: float = 1e-6) -> jnp.ndarray:
     output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
-                                       scale_inv, layernorm_type, amax_history_idx, fwd_dtype,
-                                       bwd_dtype, contracting_dims, sharding_type, dp_axis_name,
-                                       tp_axis_name, epsilon)
+                                       scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
+                                       contracting_dims, sharding_type, dp_axis_name, tp_axis_name,
+                                       epsilon)
     return output
@@ -268,7 +264,6 @@ def _layernorm_fp8_dot_fwd(
         scale,
         scale_inv,
         layernorm_type,
-        amax_history_idx,  # pylint: disable=unused-argument
         fwd_dtype,
         bwd_dtype,  # pylint: disable=unused-argument
         contracting_dims,
@@ -286,6 +281,8 @@ def _layernorm_fp8_dot_fwd(
     kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
     assert input_contracting_size == kernel_contracting_size

+    amax = FP8Helper.update_amax_history(amax)
+
     gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
     input_amax = amax[gemm_input_idx]
@@ -337,7 +334,6 @@ def _layernorm_fp8_dot_fwd(
 def _layernorm_fp8_dot_bwd(
         layernorm_type,
-        amax_history_idx,
         fwd_dtype,
         bwd_dtype,
         contracting_dims,  # pylint: disable=unused-argument
@@ -386,9 +382,9 @@ def _layernorm_fp8_dot_bwd(
         grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
         grad_beta = None

-    amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
-    amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
-    amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
+    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
+    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
+    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])

     if is_dp_enabled(sharding_type.value[0]):
         wgrad = jax.lax.psum(wgrad, dp_axis_name)
...
@@ -101,7 +101,6 @@ def fp8_ln_mlp(
                ln_scale: jnp.ndarray,
                ln_bias: jnp.ndarray,
                layernorm_type: str,
-               amax_history_idx: int,
                fwd_dtype: TEDType,
                bwd_dtype: TEDType,
                epsilon: float = 1e-6,
@@ -130,8 +129,8 @@ def fp8_ln_mlp(
     assert activations == ('gelu', 'linear')
     if major_sharding_type is MajorShardingType.SINGLE:
         res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
-                       scale_inv, layernorm_type, amax_history_idx, activations, epsilon, fwd_dtype,
-                       bwd_dtype, contracting_dims, major_sharding_type, "", "")
+                       scale_inv, layernorm_type, activations, epsilon, fwd_dtype, bwd_dtype,
+                       contracting_dims, major_sharding_type, "", "")
     else:
         dp_axis_name = "batch"
         tp_axis_name = "model"
@@ -177,7 +176,6 @@ def fp8_ln_mlp(
         partial_fp8_mlp = partial(_fp8_mlp,
                                   layernorm_type=layernorm_type,
-                                  amax_history_idx=amax_history_idx,
                                   activations=activations,
                                   epsilon=epsilon,
                                   fwd_dtype=fwd_dtype,
@@ -198,10 +196,10 @@ def fp8_ln_mlp(
     return res


-@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
+@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17))
 def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
              kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
-             scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, amax_history_idx: int,
+             scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
              activations: Sequence[Union[str, Callable]], epsilon: float, fwd_dtype: TEDType,
              bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
              major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
@@ -215,7 +213,6 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
                           scale,
                           scale_inv,
                           layernorm_type,
-                          amax_history_idx,
                           activations,
                           epsilon,
                           fwd_dtype,
@@ -238,7 +235,6 @@ def _fp8_mlp_fwd(
         scale,
         scale_inv,
         layernorm_type,
-        amax_history_idx,  # pylint: disable=unused-argument
         activations,
         epsilon,
         fwd_dtype,
@@ -266,6 +262,8 @@ def _fp8_mlp_fwd(
     kernel_1_ = jnp.reshape(kernel_1, (kernel_1_pre_size, -1))
     kernel_2_ = jnp.reshape(kernel_2, (kernel_2_pre_size, -1))

+    amax = FP8Helper.update_amax_history(amax)
+
     gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
     input_amax = amax[gemm1_input_idx]
@@ -335,7 +333,6 @@ def _fp8_mlp_fwd(
 def _fp8_mlp_bwd(
         layernorm_type,
-        amax_history_idx,
         activations,  # pylint: disable=unused-argument
         epsilon,
         fwd_dtype,
@@ -405,12 +402,12 @@ def _fp8_mlp_bwd(
         grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
         grad_beta = None

-    amax = amax.at[gemm1_input_idx, amax_history_idx].set(ln_out_amax[0])
-    amax = amax.at[gemm1_kernel_idx, amax_history_idx].set(kernel_1_amax[0])
-    amax = amax.at[gemm1_grad_idx, amax_history_idx].set(dgelu_amax[0])
-    amax = amax.at[gemm2_input_idx, amax_history_idx].set(gated_gelu_amax[0])
-    amax = amax.at[gemm2_kernel_idx, amax_history_idx].set(kernel_2_amax[0])
-    amax = amax.at[gemm2_grad_idx, amax_history_idx].set(grad_amax[0])
+    amax = amax.at[gemm1_input_idx, 0].set(ln_out_amax[0])
+    amax = amax.at[gemm1_kernel_idx, 0].set(kernel_1_amax[0])
+    amax = amax.at[gemm1_grad_idx, 0].set(dgelu_amax[0])
+    amax = amax.at[gemm2_input_idx, 0].set(gated_gelu_amax[0])
+    amax = amax.at[gemm2_kernel_idx, 0].set(kernel_2_amax[0])
+    amax = amax.at[gemm2_grad_idx, 0].set(grad_amax[0])

     if major_sharding_type in (MajorShardingType.DP, MajorShardingType.DPTP):
         wgrad_1 = jax.lax.psum(wgrad_1, dp_axis_name)
...
@@ -694,9 +694,10 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
         in_axes = [{dp_dim: dp_axis_name}]
         input_new_shapes = [input_new_shape]

-        return ShardingMeta(tuple(in_axes), ({
-            dp_dim: dp_axis_name
-        }), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
+        out_axes = in_axes[0]
+
+        return ShardingMeta(tuple(in_axes), out_axes, {dp_axis_name: dp_mesh_axis},
+                            input_new_shapes, [input_shape])
     def get_tp_col_sharding_meta(self,
                                  input_shape: Tuple,
@@ -764,9 +765,10 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
         in_axes = [{tp_dim: tp_axis_name}]
         input_new_shapes = [input_new_shape]

-        return ShardingMeta(tuple(in_axes), ({
-            tp_dim: tp_axis_name
-        }), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
+        out_axes = in_axes[0]
+
+        return ShardingMeta(tuple(in_axes), out_axes, {tp_axis_name: tp_mesh_axis},
+                            input_new_shapes, [input_shape])
     @staticmethod
     def _get_dptp_sharding_meta(input_shape: Tuple,
@@ -794,7 +796,7 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
         in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}]
         input_new_shapes = [input_new_shape]

-        out_axes = in_axes
+        out_axes = in_axes[0]

         return ShardingMeta(tuple(in_axes), out_axes, {
             dp_axis_name: dp_mesh_axis,
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX softmax modules"""
from enum import Enum
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
from .cpp_extensions import scaled_softmax_fwd
from .cpp_extensions import scaled_softmax_bwd
from .cpp_extensions import scaled_masked_softmax_fwd
from .cpp_extensions import scaled_masked_softmax_bwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_fwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
from .sharding import get_softmax_sharding_meta, ShardingType
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
class SoftmaxType(Enum):
"""SoftmaxType."""
SCALED = "scaled"
SCALED_MASKED = "scaled_masked"
SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"
def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: int, q_seqlen: int,
k_seqlen: int, dtype: jnp.dtype):
"""check softmax available"""
if softmax_type is SoftmaxType.SCALED:
return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
dtype)
if softmax_type is SoftmaxType.SCALED_MASKED:
return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
dtype)
if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype)
raise NotImplementedError
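A usage sketch (shapes and dtype illustrative): query kernel availability before dispatching to the fused path, and fall back to a plain JAX softmax otherwise.

import jax
import jax.numpy as jnp

batch, heads, q_seqlen, k_seqlen = 4, 8, 128, 128
logits = jnp.zeros((batch, heads, q_seqlen, k_seqlen), dtype=jnp.float16)

if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads,
                               q_seqlen, k_seqlen, logits.dtype):
    probs = softmax(logits, scale_factor=0.125)   # fused TE kernel
else:
    probs = jax.nn.softmax(logits * 0.125)        # unfused fallback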
def softmax(inputs: jnp.ndarray,
mask: Optional[jnp.ndarray] = None,
scale_factor: Optional[float] = 1.0,
softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0,
tp_dim_index: int = 1):
"""
Softmax wrapper
"""
    assert dp_dim_index == 0, \
        "Softmax currently only supports the batch dim at index 0."
    assert tp_dim_index == 1, \
        "Softmax currently only supports the head dim at index 1."
assert mask is None or mask.shape[tp_dim_index] == 1
if sharding_type is ShardingType.SINGLE:
outputs = _softmax(inputs, mask, scale_factor, softmax_type)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_softmax_sharding_meta(sharding_type,
inputs.shape,
dp_dim=dp_dim_index,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
mask_ = mask
mask_in_axis = {}
if mask_ is not None:
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
# If mask is head broadcastable (heads == 1),
# then it equals to DP sharding.
mask_sharding_meta = get_softmax_sharding_meta(ShardingType.DP,
mask_.shape,
dp_dim=dp_dim_index,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
mask_in_axis = mask_sharding_meta.in_axes[0]
partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type)
in_axes = (sharding_meta.in_axes[0], mask_in_axis)
outputs = xmap_runner(partial_softmax, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, mask_))
outputs = jnp.reshape(outputs, sharding_meta.output_shapes[0])
return outputs
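Per the assertions above, any mask must be broadcastable over heads (size 1 on dim 1). A masked-call sketch (shapes and mask dtype illustrative):

import jax.numpy as jnp

batch, heads, q_seqlen, k_seqlen = 2, 4, 64, 64
logits = jnp.zeros((batch, heads, q_seqlen, k_seqlen), dtype=jnp.float16)
mask = jnp.zeros((batch, 1, q_seqlen, k_seqlen), dtype=jnp.uint8)   # size 1 on the head dim

probs = softmax(logits, mask=mask, scale_factor=0.125,
                softmax_type=SoftmaxType.SCALED_MASKED)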
@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(inputs, mask, scale_factor, softmax_type):
output, _ = _softmax_fwd(inputs, mask, scale_factor, softmax_type)
return output
def _softmax_fwd(inputs, mask, scale_factor, softmax_type):
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
outputs = scaled_masked_softmax_fwd(inputs, mask, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = scaled_upper_triang_masked_softmax_fwd(inputs, scale_factor)
else:
outputs = scaled_softmax_fwd(inputs, scale_factor)
return outputs, (outputs, mask)
def _softmax_bwd(scale_factor, softmax_type, ctx, grad_outputs):
softmax_outputs, mask = ctx
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
dgrad = scaled_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
dgrad = scaled_upper_triang_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
else:
dgrad = scaled_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
return (dgrad, None)
_softmax.defvjp(_softmax_fwd, _softmax_bwd)
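The registration above follows the standard jax.custom_vjp pattern used throughout this change: nondiff args lead the backward signature, the forward returns residuals, and the backward returns one cotangent per differentiable argument. A minimal self-contained analogue:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_exp(x, scale):
    return jnp.exp(scale * x)

def scaled_exp_fwd(x, scale):
    y = jnp.exp(scale * x)
    return y, (y,)                    # residuals saved for the backward pass

def scaled_exp_bwd(scale, res, g):    # nondiff arg `scale` comes first
    (y,) = res
    return (g * scale * y,)           # one gradient per differentiable input

scaled_exp.defvjp(scaled_exp_fwd, scaled_exp_bwd)

print(jax.grad(scaled_exp)(1.0, 2.0))   # 2 * e**2, about 14.778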