Unverified commit bc9d57a3, authored by Jeng Bai-Cheng, committed by GitHub

Add TE/JAX high-level modules, unittests and examples (#54)

* add transformer module, unittests and examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update tests/jax/test_sharding.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update transformer_engine/jax/transformer.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint: disable=line-too-long
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove pylint: disable=too-many-func-args
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Fix the wrong broadcasting dim of dropout masks when transpose_bs is enabled.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
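(Illustrative sketch, not TE code: with transpose_batch_sequence=True the activations are laid out as (seqlen, batch, hidden), as in the examples below, so a dropout mask shared along the sequence must broadcast over dim 0 rather than dim 1. The unittests pin this down via hidden_dropout_dims=(sequence_dim,) with sequence_dim = 0.)

import jax
import jax.numpy as jnp

x = jnp.ones((512, 32, 1024))    # (seqlen, batch, hidden) when transpose_bs is enabled
rng = jax.random.PRNGKey(0)
# Mask has size 1 on the sequence axis (dim 0), so it broadcasts over seqlen.
keep = jax.random.bernoulli(rng, p=0.9, shape=(1, 32, 1024))
y = x * keep / 0.9    # standard inverted-dropout scaling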

* Enable 2xACC for WGRAD and DGRAD by default
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename LayerNormMlpBlock to LayerNormMLP
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor to avoid line-too-long
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename amax_history_size to amax_history_len
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* align dropout mask to TE/PyTorch as default
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* enlarge atol for decoder unittests

Two decoder unittests pass in an old JAX container (e.g., 23.02)
but fail in the latest container (devel).

1. The actual (-0.020264) and desired (-0.020386) values are very close.
2. The TE kernels are unchanged, so the difference should come from
   new codegen behavior in XLA.

Thus, it is ordinary accumulated floating-point error.
Enlarge atol to avoid unittest failures.
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
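(For reference: |-0.020264 - (-0.020386)| = 1.22e-4, so any atol at or below roughly 1e-4 fails here; the enlarged tolerance, e.g. atol=2e-4 in the BF16 decoder unittests below, comfortably covers a gap of this size.)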

* Adding Amax History Support

1. hide the amax update inside custom_vjp
2. replace amax indexing with roll (using a circular buffer)
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
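(A minimal sketch of the roll-based circular buffer described in item 2, assuming an (num_meta, history_len) amax history; this is an illustration, not TE's actual implementation:)

import jax.numpy as jnp

def update_amax_history(amax_history, new_amax):
    # Shift the history right by one slot; the oldest entry wraps around
    # and is immediately overwritten by the newest amax at index 0, so no
    # running write-index has to be threaded through the custom_vjp.
    rolled = jnp.roll(amax_history, shift=1, axis=-1)
    return rolled.at[..., 0].set(new_amax)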

* move kernel_init to __post_init__
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor encoder examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update transformer_engine/jax/fp8.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove envvar regarding 2xACC
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* remove unused import
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5925d444
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder with BF16 Training on single GPU"""
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state

import transformer_engine.jax as te

PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024


def network():
    """NLP Encoder"""
    encoder = te.TransformerLayer(hidden_size=HIDDEN,
                                  mlp_hidden_size=4 * HIDDEN,
                                  hidden_dropout=0.0,
                                  attention_dropout=0.0,
                                  layernorm_type='rmsnorm',
                                  mlp_activations=('gelu', 'linear'),
                                  layer_type=te.TransformerLayerType.ENCODER,
                                  transpose_batch_sequence=True,
                                  dtype=jnp.bfloat16)
    return encoder


def synthesis_data(data_rng):
    """Dataset generator"""
    return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)


def train_step(batch, state, others):
    """Training function."""

    def loss_fn(collections):
        logits = state.apply_fn(collections, batch)
        loss = jnp.mean(logits)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
    grads, params_grads = grads.pop(PARAMS_KEY)
    state = state.apply_gradients(grads=params_grads)
    return loss, state, others


def test_encoder():
    """Encoder example"""
    rng = jax.random.PRNGKey(0)
    rng, init_rng, data_rng = jax.random.split(rng, 3)
    inputs = synthesis_data(data_rng)

    encoder = network()
    variables = jax.jit(encoder.init)(init_rng, inputs)
    variables, params = variables.pop(PARAMS_KEY)

    optimizer = optax.sgd(0.001, 0.9)
    state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)

    jitted_train_step = jax.jit(train_step)
    for i in range(5):
        rng, data_rng = jax.random.split(rng)
        inputs = synthesis_data(data_rng)
        loss, state, variables = jitted_train_step(inputs, state, variables)
        print(f"Step {i} - Loss: {loss}")


if __name__ == "__main__":
    test_encoder()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder with FP8 Training on single GPU"""
import jax
import jax.numpy as jnp
import optax
from cuda import cudart
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state

import transformer_engine.jax as te
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.common.recipe import DelayedScaling

PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024


def gpu_has_fp8():
    """GPU arch has to support FP8"""
    cudaSuccess = cudart.cudaError_t.cudaSuccess
    ret, gpu_id = cudart.cudaGetDevice()
    assert ret == cudaSuccess
    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
    _, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
    flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
    _, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
    sm_arch = major * 10 + minor
    return sm_arch >= 89


def network():
    """NLP Encoder"""
    encoder = te.TransformerLayer(hidden_size=HIDDEN,
                                  mlp_hidden_size=4 * HIDDEN,
                                  hidden_dropout=0.0,
                                  attention_dropout=0.0,
                                  layernorm_type='rmsnorm',
                                  mlp_activations=('gelu', 'linear'),
                                  layer_type=te.TransformerLayerType.ENCODER,
                                  transpose_batch_sequence=True,
                                  dtype=jnp.bfloat16)
    return encoder


def synthesis_data(data_rng):
    """Dataset generator"""
    return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)


def train_step(batch, state, others):
    """Training function."""

    def loss_fn(collections):
        logits = state.apply_fn(collections, batch)
        loss = jnp.mean(logits)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
    grads, params_grads = grads.pop(PARAMS_KEY)
    state = state.apply_gradients(grads=params_grads)
    others = FP8Helper.update_fp8_metas(grads)
    return loss, state, others


def test_encoder():
    """Encoder example"""
    if not gpu_has_fp8():
        print("GPU doesn't support FP8")
        return

    rng = jax.random.PRNGKey(0)
    rng, init_rng, data_rng = jax.random.split(rng, 3)
    inputs = synthesis_data(data_rng)
    optimizer = optax.sgd(0.001, 0.9)

    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=FP8Format.HYBRID)):
        encoder = network()
        variables = jax.jit(encoder.init)(init_rng, inputs)
        variables, params = variables.pop(PARAMS_KEY)
        state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)

        jitted_train_step = jax.jit(train_step)
        assert "fp8" in str(jax.make_jaxpr(jitted_train_step)(inputs, state, variables))

        for i in range(5):
            rng, data_rng = jax.random.split(rng)
            inputs = synthesis_data(data_rng)
            loss, state, variables = jitted_train_step(inputs, state, variables)
            print(f"Step {i} - Loss: {loss}")


if __name__ == "__main__":
    test_encoder()
@@ -6,3 +6,4 @@ set -xe
 : ${TE_PATH:=/opt/transformerengine}
 pytest -Wignore -v $TE_PATH/tests/jax
+pytest -Wignore -v $TE_PATH/examples/jax
@@ -60,7 +60,7 @@ class TestFP8Dot:
         def func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
@@ -68,7 +68,7 @@ class TestFP8Dot:
             # y = input, matrix 2d (weight)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))
         value_n_grad_func = value_and_grad(func, (0, 1))
         value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
@@ -84,13 +84,13 @@ class TestFP8Dot:
         def func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))
         value_n_grad_func = value_and_grad(func, (0, 1))
         value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
@@ -104,13 +104,13 @@ class TestFP8Dot:
         b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                       fp8_metas_scale_inv)
-        primitive_out = fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None))
+        primitive_out = fp8_dot(fp8_gemm_pkg, *_format2dtypes(None))
         ref_out = jnp.dot(a, b)
         assert_allclose(primitive_out, ref_out)
@@ -128,7 +128,7 @@ class TestFP8Dot:
         b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
@@ -136,12 +136,12 @@ class TestFP8Dot:
         # calculate amax
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
-        primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
+        primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)
         # calculate scale by amax
         fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
         fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
-        primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
+        primitive_out = fp8_dot(fp8_gemm_pkg, *compute_type)
         ref_out = jnp.dot(a, b)
         ref_out = ref_out.astype(jnp.float32)
@@ -158,13 +158,13 @@ class TestFP8Dot:
         def primitive_func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.mean(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
+            return jnp.mean(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None)))
         def ref_func(x, y):
             return jnp.mean(jnp.dot(x, y))
@@ -193,7 +193,7 @@ class TestFP8Dot:
         b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
@@ -201,7 +201,7 @@ class TestFP8Dot:
         def primitive_func(x, y, metas):
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *metas)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *compute_type))
         def ref_func(x, y):
             return jnp.sum(jnp.dot(x, y))
@@ -232,13 +232,13 @@ class TestFP8Dot:
         def primitive_func(x, y):
             fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
-            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
+            fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
                                        jnp.float32)
             fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
             fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
                                           fp8_metas_scale_inv)
-            return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None), ((2, 3), (0, 1))))
+            return jnp.sum(fp8_dot(fp8_gemm_pkg, *_format2dtypes(None), ((2, 3), (0, 1))))
         def ref_func(x, y):
             return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))
@@ -266,7 +266,7 @@ class TestFP8Dot:
         s = jax.random.uniform(subkeys[3], (k,), jnp.bfloat16, 5, 8)
         fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
-        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_SIZE),
+        fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN),
                                    jnp.float32)
         fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
         fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
@@ -283,7 +283,6 @@ class TestFP8Dot:
                     ln_s,
                     None,
                     "rmsnorm",
-                    0,
                     *compute_type,
                     activations=activations))
@@ -305,7 +304,6 @@ class TestFP8Dot:
                       amax: jnp.ndarray,
                       scale: jnp.ndarray,
                       scale_inv: jnp.ndarray,
-                      amax_history_idx: int,
                       fwd_dtype,
                       bwd_dtype,
                       epsilon=1e-6,
@@ -323,7 +321,6 @@ class TestFP8Dot:
                                           scale[:FP8Helper.NUM_META_PER_GEMM],
                                           scale_inv[:FP8Helper.NUM_META_PER_GEMM])
         linear_1_out = fp8_dot(fp8_gemm_1_pkg,
-                               amax_history_idx,
                                fwd_dtype,
                                bwd_dtype,
                                contracting_dims,
@@ -341,7 +338,6 @@ class TestFP8Dot:
                                           scale[FP8Helper.NUM_META_PER_GEMM:],
                                           scale_inv[FP8Helper.NUM_META_PER_GEMM:])
         output = fp8_dot(fp8_gemm_2_pkg,
-                         amax_history_idx,
                          fwd_dtype,
                          bwd_dtype,
                          contracting_dims,
@@ -350,7 +346,7 @@ class TestFP8Dot:
         def ref_func(x, ln_s, y, z, metas):
             return jnp.mean(
-                fp8_ln_mlp_py(x, ln_s, y, z, *metas, 0, *compute_type, activations=activations))
+                fp8_ln_mlp_py(x, ln_s, y, z, *metas, *compute_type, activations=activations))
         value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3)))
         value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3)))
@@ -27,12 +27,12 @@ class TestFP8Helper(unittest.TestCase):
         margin = 5.0
         fp8_format = FP8Format.E4M3
         update_fp8meta_interval = 10
-        amax_history_size = 10
+        amax_history_len = 10
         FP8Helper.initialize(margin=margin,
                              fp8_format=fp8_format,
                              update_fp8meta_interval=update_fp8meta_interval,
-                             amax_history_size=amax_history_size)
+                             amax_history_len=amax_history_len)
         self.assertEqual(
             FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
@@ -46,15 +46,15 @@ class TestFP8Helper(unittest.TestCase):
             "FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
             f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
         self.assertEqual(
-            FP8Helper.AMAX_HISTORY_SIZE, amax_history_size,
-            f"FP8Helper.AMAX_HISTORY_SIZE initialization failed, should be {amax_history_size}"
-            f" but got {FP8Helper.AMAX_HISTORY_SIZE}.")
+            FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
+            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
+            f" but got {FP8Helper.AMAX_HISTORY_LEN}.")
         FP8Helper.finalize()

     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_update_fp8_metas(self):
-        FP8Helper.initialize(margin=3.0, amax_history_size=5)
+        FP8Helper.initialize(margin=3.0, amax_history_len=3)
         seed = 0
         key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
@@ -72,13 +72,15 @@ class TestFP8Helper(unittest.TestCase):
             sf = np.where(np.isfinite(amax), sf, scale)
             return np.where(exp < 0, 1 / sf, sf)

-        meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_SIZE)
+        meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
         fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)

         fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
-        fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1, jnp.ones(meta_shape))
+        fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1[:, 0:1],
+                                         jnp.ones(meta_shape))
         fp8_scale_inv_array1 = 1 / fp8_scale_array1
         fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
-        fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2, jnp.ones(meta_shape))
+        fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2[:, 0:1],
+                                         jnp.ones(meta_shape))
         fp8_scale_inv_array2 = 1 / fp8_scale_array2

         state = flax.core.frozen_dict.FrozenDict({
@@ -156,6 +158,9 @@ class TestFP8Functions(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_fp8_autocast(self):
+        FP8Helper.finalize()    # Ensure the test is not affected by previous tests.
+        self._check_defult_state()
+
         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
             self.assertFalse(FP8Helper.enable_fp8())
@@ -167,7 +172,7 @@ class TestFP8Functions(unittest.TestCase):
             self.assertEqual(FP8Helper.MARGIN, ds.margin)
             self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
             self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
+            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
         self._check_defult_state()

         ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
@@ -176,16 +181,12 @@ class TestFP8Functions(unittest.TestCase):
             self.assertEqual(FP8Helper.MARGIN, ds.margin)
             self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
             self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
+            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
         self._check_defult_state()

-        ds = DelayedScaling(amax_history_len=2)
-        with self.assertRaises(AssertionError):
-            with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(amax_history_len=2)):
-                pass

     @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
     def test_fp8_autocast_with_sharding_resource(self):
+        FP8Helper.finalize()    # Ensure the test is not affected by previous tests.
         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
@@ -213,7 +214,7 @@ class TestFP8Functions(unittest.TestCase):
             self.assertEqual(FP8Helper.MARGIN, ds.margin)
             self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
             self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
-            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
+            self.assertEqual(FP8Helper.AMAX_HISTORY_LEN, ds.amax_history_len)
             self.assertEqual(infer_major_sharding_type(), mst)
         self._check_defult_state()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from functools import partial

import flax
import jax
import jax.numpy as jnp
import pytest

from transformer_engine.common.recipe import Format
from transformer_engine.jax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper
from utils import assert_allclose, is_fp8_supported
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer


def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
    output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
    return jnp.mean(output)


def generate_test_rngs():
    data_rng = jax.random.PRNGKey(0)
    init_rng = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
    apply_rng = {'dropout': jax.random.PRNGKey(3)}
    return data_rng, init_rng, apply_rng


def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs):
    layer = layer_cls()
    variables = layer.init(init_rng, *diff_inputs, *no_diff_inputs)
    others, params = variables.pop('params')
    del variables
    return layer, params, others


def compare_frozen_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    for key in ref_fd:
        assert key in test_fd, \
            f"{key} not found in test FrozenDict {test_fd}"
        assert isinstance(test_fd[key], type(ref_fd[key])), \
            f"The data type does not match between ref and test " \
            f"FrozenDict on {key=}"
        if isinstance(ref_fd[key], flax.core.frozen_dict.FrozenDict):
            compare_frozen_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(ref_fd[key],
                            test_fd[key],
                            rtol=rtol,
                            atol=atol,
                            err_msg=f"{key=} is not close")


DATA_SHAPE = [(128, 32, 512), (512, 32, 512)]    # (seqlen, batch, emb_dim)
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_DROPOUT_RATE = "dropout_rate"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'

BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}

ATTRS = [{
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
    _KEY_OF_LAYERNORM_TYPE: 'layernorm',
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
    _KEY_OF_OUTPUT_LAYERNORM: True
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROP_PATH: 0.1
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_FUSE_QKV_PARAMS: False
}, {
    _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
    _KEY_OF_DROPOUT_RATE: 0.0,
    _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
    _KEY_OF_FUSE_MLP_WI: True
}]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class TestEncoderLayer:

    @staticmethod
    def sync_params(ref, target, attrs):
        fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)

        unfreeze_target = target.unfreeze()
        if fuse_qkv:
            unfreeze_target['attention']['qkv']['kernel'] = \
                jnp.reshape(ref['attention']['qkv']['kernel'],
                            unfreeze_target['attention']['qkv']['kernel'].shape)
        else:
            unfreeze_target['attention']['query']['kernel'] = \
                ref['attention']['query']['kernel']
            unfreeze_target['attention']['key']['kernel'] = \
                ref['attention']['key']['kernel']
            unfreeze_target['attention']['value']['kernel'] = \
                ref['attention']['value']['kernel']
        unfreeze_target['mlp']['wi_kernel'] = \
            jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
        unfreeze_target['mlp']['wo_kernel'] = \
            ref['mlp']['wo']['kernel']
        return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
                            dtype=dtype,
                            **te_layer_attrs)

        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)

        ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
        test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        data_rng, init_rng, apply_rng = generate_test_rngs()
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        ref_masks = (1 - padded_mask,)
        test_masks = (None, padded_mask)    # The second arg of Transformer is encoded tokens.

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.ENCODER,
                            dtype=dtype,
                            **te_layer_attrs)

        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params, attrs)

        if FP8Helper.enable_fp8():
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                 has_aux=False)(inputs, test_masks, test_params,
                                                                test_others, test_layer, apply_rng)
                _, fp8_meta_grad = tmp_grad[0].pop(FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
                                     apply_rng)
        test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
                                       apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
        assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

        def reorganize_test_wgrad(test_wgrad, attrs):
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)

            attn_name = 'attention'
            unfreeze_test_wgrad = test_wgrad.unfreeze()
            if "output_layernorm" not in attrs:
                unfreeze_test_wgrad['pre_attention_layer_norm'] = {}
                pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                unfreeze_test_wgrad['pre_attention_layer_norm']['scale'] = \
                    unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
                    unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
                        unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                    del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
            if fuse_qkv:
                unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
                    jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
                                (unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))

            unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
            unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                unfreeze_test_wgrad['mlp']['scale']
            del unfreeze_test_wgrad['mlp']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['mlp']:
                unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['mlp']['ln_bias']
                del unfreeze_test_wgrad['mlp']['ln_bias']
            unfreeze_test_wgrad['mlp']['wi'] = {}
            unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
                jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
                            (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
            del unfreeze_test_wgrad['mlp']['wi_kernel']
            unfreeze_test_wgrad['mlp']['wo'] = {}
            unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                unfreeze_test_wgrad['mlp']['wo_kernel']
            del unfreeze_test_wgrad['mlp']['wo_kernel']
            return flax.core.frozen_dict.FrozenDict(unfreeze_test_wgrad)

        compare_frozen_dict(ref_grads[1],
                            reorganize_test_wgrad(test_grads[1], attrs),
                            rtol=rtol,
                            atol=atol)    # wgrad

        del data_rng, init_rng, apply_rng

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
        FP8Helper.finalize()

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
        FP8Helper.finalize()
class TestDecoderLayer:

    @staticmethod
    def sync_params(ref, target, attrs):
        fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)

        unfreeze_target = target.unfreeze()
        if fuse_qkv:
            unfreeze_target['self_attention']['qkv']['kernel'] = \
                jnp.reshape(ref['self_attention']['qkv']['kernel'],
                            unfreeze_target['self_attention']['qkv']['kernel'].shape)
            unfreeze_target['encoder_decoder_attention']['kv']['kernel'] = \
                jnp.reshape(ref['encoder_decoder_attention']['kv']['kernel'],
                            unfreeze_target['encoder_decoder_attention']['kv']['kernel'].shape)
        else:
            unfreeze_target['self_attention']['query']['kernel'] = \
                ref['self_attention']['query']['kernel']
            unfreeze_target['self_attention']['key']['kernel'] = \
                ref['self_attention']['key']['kernel']
            unfreeze_target['self_attention']['value']['kernel'] = \
                ref['self_attention']['value']['kernel']
        unfreeze_target['encoder_decoder_attention']['query']['kernel'] = \
            ref['encoder_decoder_attention']['query']['kernel']
        unfreeze_target['mlp']['wi_kernel'] = \
            jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
        unfreeze_target['mlp']['wo_kernel'] = \
            ref['mlp']['wo']['kernel']
        return ref, flax.core.frozen_dict.FrozenDict(unfreeze_target)

    def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        data_rng, init_rng, apply_rng = generate_test_rngs()
        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        ref_masks = (1 - causal_mask, 1 - padded_mask)
        test_masks = (causal_mask, padded_mask)

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
                            dtype=dtype,
                            **te_layer_attrs)

        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)

        ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
        test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)

        del data_rng, init_rng, apply_rng

    def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        data_rng, init_rng, apply_rng = generate_test_rngs()
        batch = data_shape[1] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[0]
        seqlen = data_shape[0] if _KEY_OF_TRANSPOSE_BS in attrs else data_shape[1]
        inputs = (jax.random.normal(data_rng, data_shape,
                                    dtype), jax.random.normal(data_rng, data_shape, dtype))
        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        ref_masks = (1 - causal_mask, 1 - padded_mask)
        test_masks = (causal_mask, padded_mask)

        te_layer_attrs = {}
        for k, v in attrs.items():
            if k == 'dropout_rate':
                te_layer_attrs['attention_dropout'] = v
                te_layer_attrs['hidden_dropout'] = v
            elif k == 'fuse_mlp_wi':
                continue
            else:
                te_layer_attrs[k] = v
        ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
        sequence_dim = 0
        layer_cls = partial(TransformerLayer,
                            hidden_dropout_dims=(sequence_dim,),
                            layer_type=TransformerLayerType.DECODER,
                            dtype=dtype,
                            **te_layer_attrs)

        ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
                                                           ref_masks)
        test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
                                                              test_masks)

        ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params, attrs)

        if FP8Helper.enable_fp8():
            for _ in range(4):
                _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
                                                 has_aux=False)(inputs, test_masks, test_params,
                                                                test_others, test_layer, apply_rng)
                _, fp8_meta_grad = tmp_grad[0].pop(FP8Helper.FP8_COLLECTION_NAME)
                test_others = FP8Helper.update_collections(
                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
                test_others = FP8Helper.update_fp8_metas(test_others)
                del tmp_grad, fp8_meta_grad

        grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
                                     apply_rng)
        test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
                                       apply_rng)

        assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
        assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol)    # dgrad

        def reorganize_test_wgrad(test_wgrad, attrs):
            fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True)

            attn_name = 'self_attention'
            unfreeze_test_wgrad = test_wgrad.unfreeze()
            if "output_layernorm" not in attrs:
                unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
                pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
                unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
                    unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
                if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
                    unfreeze_test_wgrad['pre_self_attention_layer_norm']['ln_bias'] = \
                        unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
                    del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
            if fuse_qkv:
                unfreeze_test_wgrad[attn_name]['qkv']['kernel'] = \
                    jnp.reshape(unfreeze_test_wgrad[attn_name]['qkv']['kernel'],
                                (unfreeze_test_wgrad[attn_name]['qkv']['kernel'].shape[0], -1))
                attn_name = 'encoder_decoder_attention'
                unfreeze_test_wgrad[attn_name]['kv']['kernel'] = \
                    jnp.reshape(unfreeze_test_wgrad[attn_name]['kv']['kernel'],
                                (unfreeze_test_wgrad[attn_name]['kv']['kernel'].shape[0], -1))

            unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
            unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
                unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
            del unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['encoder_decoder_attention']['query']:
                unfreeze_test_wgrad['pre_cross_attention_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
                del unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
            unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
            unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
                unfreeze_test_wgrad['mlp']['scale']
            del unfreeze_test_wgrad['mlp']['scale']
            if 'ln_bias' in unfreeze_test_wgrad['mlp']:
                unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
                    unfreeze_test_wgrad['mlp']['ln_bias']
                del unfreeze_test_wgrad['mlp']['ln_bias']
            unfreeze_test_wgrad['mlp']['wi'] = {}
            unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
                jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
                            (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
            del unfreeze_test_wgrad['mlp']['wi_kernel']
            unfreeze_test_wgrad['mlp']['wo'] = {}
            unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
                unfreeze_test_wgrad['mlp']['wo_kernel']
            del unfreeze_test_wgrad['mlp']['wo_kernel']
            return flax.core.frozen_dict.FrozenDict(unfreeze_test_wgrad)

        compare_frozen_dict(ref_grads[1],
                            reorganize_test_wgrad(test_grads[1], attrs),
                            rtol=rtol,
                            atol=atol)    # wgrad

        del data_rng, init_rng, apply_rng

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
        FP8Helper.finalize()

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs):
        FP8Helper.finalize()    # Ensure FP8 disabled.
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)

    @pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    @pytest.mark.parametrize('attrs', ATTRS)
    def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
        FP8Helper.initialize(fp8_format=fp8_format)
        self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
        FP8Helper.finalize()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import tempfile
import unittest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state

from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import DenseGeneral
from transformer_engine.jax.fp8 import FP8Helper
from utils import is_fp8_supported


class MLPNN(nn.Module):
    use_fp8_dense: bool = True

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))    # flatten
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        features = [256, 256, 128]
        for feature in features:
            x = DenseGeneral(features=feature, transpose_batch_sequence=False,
                             dtype=jnp.bfloat16, use_bias=True)(x) \
                if self.use_fp8_dense else nn.Dense(features=feature)(x)
            x = nn.relu(x)
        x = nn.Dense(features=10, use_bias=True)(x)
        return x


def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()


def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist', data_dir="/tmp/tensorflow-datasets/downloads")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate, momentum, use_fp8_dense):
    """Creates initial `TrainState`."""
    cnn = MLPNN(use_fp8_dense=use_fp8_dense)
    variables = cnn.init(rng, jnp.ones([32, 28, 28, 1]))
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=variables['params'],
                                         tx=tx), variables


@partial(jax.jit, static_argnums=(3,))
def train_step(state, others, batch, use_fp8_dense):
    """Train for a single step."""

    def loss_fn(collections):
        logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(collections, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(others)
    state = state.apply_gradients(grads=grads['params'])
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics, grads


def train_epoch(state, variables, train_ds, batch_size, rng, use_fp8_dense):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for idx, perm in enumerate(perms):
        idx = idx + 1
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics, grads = train_step(state, variables, batch, use_fp8_dense)
        updated_coll = {'params': state.params}
        if use_fp8_dense:
            updated_coll[FP8Helper.FP8_COLLECTION_NAME] \
                = grads[FP8Helper.FP8_COLLECTION_NAME]
        variables = FP8Helper.update_collections(updated_coll, variables)
        batch_metrics.append(metrics)

    if use_fp8_dense:
        variables = FP8Helper.update_fp8_metas(variables)
    return state, variables


@partial(jax.jit, static_argnums=(2,))
def eval_step(variables, batch, use_fp8_dense):
    logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(variables, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


def eval_model(variables, test_ds, batch_size, use_fp8_dense):
    test_ds_size = len(test_ds['image'])
    steps_per_epoch = test_ds_size // batch_size

    perms = np.arange(0, test_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    total_summary = {'correct': 0, 'loss': 0, 'total': 0}
    for _, perm in enumerate(perms):
        batch = {k: v[perm, ...] for k, v in test_ds.items()}
        metrics = eval_step(variables, batch, use_fp8_dense)
        metrics = jax.device_get(metrics)
        summary = jax.tree_map(lambda x: x.item(), metrics)
        total_summary['correct'] += summary['accuracy'] * batch_size
        total_summary['loss'] += summary['loss'] * batch_size
        total_summary['total'] += batch_size
    return total_summary['loss'] / total_summary['total'], \
        total_summary['correct'] / total_summary['total']


class TestMnist(unittest.TestCase):

    def setUp(self) -> None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        self.learning_rate = 0.1
        self.momentum = 0.9
        self.num_epochs = 5
        self.batch_size = 512
        self.train_ds, self.test_ds = get_datasets()
        self.margin = 0.0
        self.num_fp8_layers = 3
        self.fp8_meta_update_interval = 1
        self.temp_file = tempfile.NamedTemporaryFile()    # pylint: disable=consider-using-with
        self.fp8_ckpt_path = self.temp_file.name
        self.seed = 0

        acc_bfp16_ = self._mnist_baseline_runner()
        acc_rtol = 0.005
        self.target_accuracy = acc_bfp16_ * (1. - acc_rtol)

    def tearDown(self):
        self.temp_file.close()

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_e4m3(self):
        self._mnist_test_runner(FP8Format.E4M3)

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_hybrid(self):
        self._mnist_test_runner(FP8Format.HYBRID)

    # Skipped for now due to the lack of bf16 in TE.Format
    # def test_mnist_bfloa16(self):
    #     self._mnist_test_runner(FP8Format.BFLOAT16)

    def _mnist_baseline_runner(self):
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, False)
        del init_rng
        _, accuracy = self._train_model(state, variables, self.num_epochs, rng, False)
        return accuracy

    def _mnist_test_runner(self, fp8_format):
        FP8Helper.initialize(margin=self.margin, fp8_format=fp8_format)
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)
        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, True)
        del init_rng
        _, test_accuracy = self._train_model(state, variables, self.num_epochs, rng, True)
        self.assertGreater(
            test_accuracy, self.target_accuracy,
            f"Convergence test failed on MNIST with FP8Format.{fp8_format.name}. "
            f"Test Accuracy {test_accuracy:.4f} is lower than target {self.target_accuracy:.4f}")
        FP8Helper.finalize()

    def _train_model(self, state, variables, epochs, rng, use_fp8_dense):
        max_test_acc = 0.0
        for _ in range(0, epochs):
            rng, input_rng = jax.random.split(rng)
            state, variables = train_epoch(state, variables, self.train_ds, self.batch_size,
                                           input_rng, use_fp8_dense)
            _, test_accuracy = eval_model(variables, self.test_ds, self.batch_size, use_fp8_dense)
            max_test_acc = test_accuracy if test_accuracy > max_test_acc else max_test_acc
        return state, max_test_acc


if __name__ == '__main__':
    unittest.main()
@@ -7,6 +7,7 @@ import numpy as np
import pytest
from jax.experimental import maps
from transformer_engine.jax import extend_logical_axis_rules
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
@@ -14,6 +15,7 @@ from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
from utils import is_devices_enough
def _get_sharding_resource(mesh_names, sharding_type):
@@ -47,14 +49,25 @@ SRS = [
]
class TestShardingSideAPI:
@pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
@pytest.mark.parametrize('sr', SRS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr):
try:
target_te_rules = extend_logical_axis_rules(tuple())
extended_rules = extend_logical_axis_rules(base_rules)
assert extended_rules == (*base_rules, *target_te_rules)
assert not need_assert
except AssertionError as ae:
assert need_assert, f"{ae.args}"
class TestGeneralFunc:
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_infer_major_sharding_type(
self,
mesh_shape, # pylint: disable=unused-argument
@@ -94,7 +107,7 @@ class TestShardingMetaGenerator:
MODEL_AXIS_NAME = 'model'
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):
def stack_axes_meta(mapping):
@@ -145,7 +158,7 @@ class TestShardingMetaGenerator:
@pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
((128, 64, 512), (512, 256))])
@pytest.mark.parametrize('batch_dim_of_a', [0, 1])
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
model_dim_of_a = len(a_shape) - 1
model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
@@ -240,7 +253,7 @@ class TestShardingMetaGenerator:
@pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
@pytest.mark.parametrize('other_shape', [(256,), (512,)])
@pytest.mark.parametrize('batch_dim', [0, 1])
@pytest.mark.skipif(not is_devices_enough(DEVICE_COUNT), reason='Num of GPU is not enough')
def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
batch_dim):
...
@@ -2,12 +2,19 @@
#
# See LICENSE for license information.
import functools
import operator
from typing import Any, Callable, Tuple, Sequence, Union, Iterable, Optional
import jax
import jax.numpy as jnp
import numpy as np
from cuda import cudart
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
PRNGKey = Any
Shape = Tuple[int, ...]
@@ -32,6 +39,969 @@ def is_fp8_supported():
return sm_major >= 9
def is_devices_enough(required):
return len(jax.devices()) >= required
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
drop_path_shape = list(range(0, len(shape)))
drop_path_shape.pop(batch_dim)
return drop_path_shape
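# Illustrative sketch (not part of the commit): for shape (8, 16, 4, 64) with
# batch_dim=1, this returns [0, 2, 3], i.e. the dropout mask is broadcast over
# every axis except the batch axis, so drop-path zeroes a whole residual branch
# per sample rather than individual elements.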
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
def _canonicalize_tuple(x):
if isinstance(x, Iterable):
return tuple(x)
return (x,)
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
"""Combine attention masks.
Args:
*masks: set of attention mask arguments to combine, some can be None.
dtype: final mask dtype
Returns:
Combined mask, reduced by logical and, returns None if no masks given.
"""
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
mask, *other_masks = masks
for other_mask in other_masks:
mask = jnp.logical_and(mask, other_mask)
return mask.astype(dtype)
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases.
Args:
*masks: set of attention bias arguments to combine, some can be None.
Returns:
Combined mask, reduced by summation, returns None if no masks given.
"""
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
mask, *other_masks = masks
for other_mask in other_masks:
mask = mask + other_mask
return mask
def dot_product_attention(query: Array,
key: Array,
value: Array,
transpose_batch_sequence: bool,
bias: Optional[Array] = None,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: DType = jnp.float32,
float32_logits: bool = False):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Args:
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_heads, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_heads, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]` This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: float32)
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
# Casting logits and softmax computation for float32 for model stability.
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_heads, q_length, kv_length]
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
# Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(dtype)
# Apply attention dropout.
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
# T5 broadcasts along the "length" dim, but it is unclear which positional
# dimension that corresponds to here; we assume the query dim.
dropout_shape = list(attn_weights.shape)
dropout_shape[-2] = 1
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
keep = jnp.broadcast_to(keep, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
# Take the linear combination of `value`.
if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
class DenseGeneral(nn.Module):
"""A linear transformation with flexible axes and FP8 support.
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
"""
features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
kernel_init: Initializer = None
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
super().__post_init__()
@nn.compact
def __call__(self, inputs: Array) -> Array:
"""Applies a linear transformation to the inputs along multiple dimensions.
Args:
inputs: The nd-array to be transformed.
Returns:
The transformed input.
"""
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes('kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes)
kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,),
self.dtype,
axes=self.bias_axes)
else:
bias = None
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
return y
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block.
Attributes:
intermediate_dim: Shared dimension of hidden layers.
activations: Type of activations for each layer. Each element is either
'linear', a string function name in flax.linen, or a function.
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer.
"""
transpose_batch_sequence: bool
intermediate_dim: int = 2048
activations: Sequence[Union[str, Callable]] = ('relu',)
kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1
dtype: Any = jnp.float32
fuse_wi: bool = False
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
super().__post_init__()
@nn.compact
def __call__(self, inputs, deterministic: bool = False):
"""Applies Transformer MlpBlock module."""
# Iterate over specified MLP input activation functions.
# e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
activations = []
if self.fuse_wi:
dense_name = 'wi'
num_activations = len(self.activations)
x = DenseGeneral(self.intermediate_dim * num_activations,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'),
name=dense_name)(inputs)
x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
else:
for idx, act_fn in enumerate(self.activations):
dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}'
x = DenseGeneral(self.intermediate_dim,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'),
name=dense_name)(inputs)
x = _convert_to_activation_function(act_fn)(x)
activations.append(x)
# Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations)
# Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
x, deterministic=deterministic) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp'))
output = DenseGeneral(inputs.shape[-1],
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('mlp', 'embed'),
name='wo')(x)
return output
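# Illustrative config (hypothetical values): activations=('gelu', 'linear') gives a
# gated-GELU MLP. With fuse_wi=True both input projections come from one DenseGeneral
# of width num_activations * intermediate_dim that is then split; otherwise separate
# wi_0/wi_1 layers are created, as in the loop above.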
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class MultiHeadAttention(nn.Module):
"""Multi-head dot-product attention.
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
head_dim: dimension of each head.
dtype: the dtype of the computation.
dropout_rate: dropout rate
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
num_heads: int
head_dim: int
transpose_batch_sequence: bool
dtype: DType = jnp.float32
dropout_rate: float = 0.
kernel_init: Initializer = None
float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False
scaled_query_init: bool = True
fuse_qkv: bool = True
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
super().__post_init__()
@nn.compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
decode: bool = False,
deterministic: bool = False) -> Array:
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and projects the results to an output vector.
There are two modes: decoding and non-decoding (e.g., training). The mode is
determined by the `decode` argument. For decoding, this method is called twice,
first to initialize the cache and then for an actual decoding process. The
two calls are differentiated by the presence of 'cached_key' in the variable
dict. In the cache initialization stage, the cache variables are initialized
as zeros and will be filled in the subsequent decoding process.
In the cache initialization call, `inputs_q` has a shape [batch, length,
q_features] and `inputs_kv`: [batch, length, kv_features]. During the
incremental decoding stage, query, key and value all have the shape [batch,
1, qkv_features] corresponding to a single step.
Args:
inputs_q: input queries of shape `[batch, q_length, q_features]`.
inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
decode: Whether to prepare and use an autoregressive cache.
deterministic: Disables dropout if set to True.
Returns:
output of shape `[batch, length, q_features]`.
"""
projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
dtype=self.dtype)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / ( # pylint: disable=unnecessary-lambda-assignment
depth_scaling if self.scaled_query_init else 1.0)
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim]
def qkv_init(key, shape, dtype):
assert shape[-1] % 3 == 0
q_shape = (shape[0], shape[1] // 3)
k_shape = (shape[0], shape[1] // 3)
v_shape = (shape[0], shape[1] // 3)
q_kernel = query_init(key, q_shape, dtype)
k_kernel = self.kernel_init(key, k_shape, dtype) # pylint: disable=too-many-function-args
v_kernel = self.kernel_init(key, v_shape, dtype) # pylint: disable=too-many-function-args
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
if self.fuse_qkv:
if inputs_q is inputs_kv:
qkv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'),
kernel_init=qkv_init,
name='qkv',
dtype=self.dtype)(inputs_kv)
query, key, value = jnp.split(
qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1)
if self.scale_attn_logits:
query = query / depth_scaling
else:
query = projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'),
kernel_init=self.kernel_init,
name='kv',
dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_heads * self.head_dim], axis=-1)
else:
query = projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(query,
('length', 'batch', 'heads', 'kv'))
key = nn_partitioning.with_sharding_constraint(key, ('length', 'batch', 'heads', 'kv'))
value = nn_partitioning.with_sharding_constraint(value,
('length', 'batch', 'heads', 'kv'))
else:
query = nn_partitioning.with_sharding_constraint(query,
('batch', 'length', 'heads', 'kv'))
key = nn_partitioning.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
value = nn_partitioning.with_sharding_constraint(value,
('batch', 'length', 'heads', 'kv'))
if decode:
# Detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
# The key and value have dimension [batch, length, num_heads, head_dim],
# but we cache them as [batch, num_heads, head_dim, length] as a TPU
# fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) # pylint: disable=unnecessary-lambda-assignment
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
value.dtype)
cache_index = self.variable('cache', 'cache_index',
lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
# Sanity shape check of cached key against input query.
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
f"expected query shape {expected_shape} instead got {query.shape}.")
# Create a one-hot encoding (OHE) of the current index. NOTE: the index is increased below.
cur_index = cache_index.value
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
# In order to update the key, value caches with the current key and
# value, we move the length axis to the back, similar to what we did for
# the cached ones above.
# Note these are currently the key and value of a single position, since
# we feed one position at a time.
one_token_key = jnp.moveaxis(key, -3, -1)
one_token_value = jnp.moveaxis(value, -3, -1)
# Update key, value caches with our new 1d spatial slices.
# We implement an efficient scatter into the cache via one-hot
# broadcast and addition.
key = cached_key.value + one_token_key * one_hot_indices
value = cached_value.value + one_token_value * one_hot_indices
cached_key.value = key
cached_value.value = value
cache_index.value = cache_index.value + 1
# Move the keys and values back to their original shapes.
key = jnp.moveaxis(key, -1, -3)
value = jnp.moveaxis(value, -1, -3)
# Causal mask for cached decoder self-attention: our single query
# position should only attend to those key positions that have already
# been generated and cached, not the remaining zero elements.
mask = combine_masks(
mask,
jnp.broadcast_to(
jnp.arange(length) <= cur_index,
# (1, 1, length) represent (head dim, query length, key length)
# query length is 1 because during decoding we deal with one
# index.
# The same mask is applied to all batch elements and heads.
(batch, 1, 1, length)))
# Grab the correct relative attention bias during decoding. This is
# only required during single step decoding.
if bias is not None:
# The bias is a full attention matrix, but during decoding we only
# have to take a slice of it.
# This is equivalent to bias[..., cur_index:cur_index+1, :].
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
# Convert the boolean attention mask to an attention bias.
if mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(mask > 0,
jnp.full(mask.shape, 0.).astype(self.dtype),
jnp.full(mask.shape, -1e10).astype(self.dtype))
else:
attention_bias = None
# Add provided bias term (e.g. relative position embedding).
if bias is not None:
attention_bias = combine_biases(attention_bias, bias)
dropout_rng = None
if not deterministic and self.dropout_rate > 0.:
dropout_rng = self.make_rng('dropout')
# Apply attention.
x = dot_product_attention(query,
key,
value,
transpose_batch_sequence=self.transpose_batch_sequence,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype,
float32_logits=self.float32_logits)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'joined_kv'))
else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'joined_kv'))
# Back to the original inputs dimensions.
out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'),
dtype=self.dtype,
name='out')(x)
return out
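# Illustrative decode flow (hypothetical shapes): the first call with decode=True
# creates the 'cache' variables (cached_key/cached_value stored as
# [batch, num_heads, head_dim, length]); each later call feeds a single token
# ([batch, 1, features]) and scatters it into the cache via the one-hot broadcast
# trick described above, advancing cache_index by one.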
class LayerNorm(nn.Module):
"""T5 Layer normalization operating on the last axis of the input data."""
epsilon: float = 1e-6
dtype: Any = jnp.float32
layernorm_type: str = 'layernorm'
scale_init: Initializer = nn.initializers.ones
bias_init: Initializer = nn.initializers.zeros
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input."""
x = jnp.asarray(x, jnp.float32)
features = x.shape[-1]
scale = nn_partitioning.param_with_axes('scale',
self.scale_init, (features,),
jnp.float32,
axes=('embed',))
scale = jnp.asarray(scale, self.dtype)
if self.layernorm_type == 'layernorm':
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes('ln_bias',
self.bias_init, (features,),
jnp.float32,
axes=('embed',))
bias = jnp.asarray(bias, self.dtype)
y = jnp.asarray(y, self.dtype)
z = y * scale + bias
else:
assert self.layernorm_type == 'rmsnorm'
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
z = y * scale
return z
class RelativePositionBiases(nn.Module):
"""Adds T5-style relative positional embeddings to the attention logits.
Attributes:
num_buckets: Number of buckets to bucket distances between key and query
positions into.
max_distance: Maximum distance before everything is lumped into the last
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: Type of arrays through this module.
embedding_init: initializer for relative embedding table.
"""
num_buckets: int
max_distance: int
num_heads: int
dtype: Any
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
@staticmethod
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are
invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative
positions <=-max_distance map to the same bucket. This should allow for
more graceful generalization to longer sequences than the model has been
trained on.
Args:
relative_position: an int32 array
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += (n < 0).astype(np.int32) * num_buckets
n = np.abs(n)
else:
n = np.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1)
ret += np.where(is_small, n, val_if_large)
return ret
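# Worked example (not part of the commit), with bidirectional=True,
# num_buckets=32 and max_distance=128 (so 16 buckets per direction, max_exact=8):
#   relative_position = -3   -> n = 3   -> bucket 3             (small, exact)
#   relative_position = +3   -> n = -3  -> bucket 16 + 3 = 19   (forward offset)
#   relative_position = -200 -> n = 200 -> clipped to bucket 15 (log-spaced tail)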
@nn.compact
def __call__(self, qlen, klen, bidirectional=True):
"""Produce relative position embedding attention biases.
Args:
qlen: attention query length.
klen: attention key length.
bidirectional: whether to allow positive memory-query relative position
embeddings.
Returns:
output: `(1, num_heads, q_len, k_len)` attention bias
"""
context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(relative_position,
bidirectional=bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance)
relative_attention_bias = nn_partitioning.param_with_axes(
'rel_embedding',
self.embedding_init, (self.num_heads, self.num_buckets),
jnp.float32,
axes=('heads', 'relpos_buckets'))
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
# Instead of using a slow gather, we create a leading-dimension one-hot
# array from rp_bucket and use it to perform the gather-equivalent via a
# contraction, i.e.:
# (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
# This is equivalent to relative_attention_bias[:, rp_bucket]
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
# --> shape (qlen, klen, num_heads)
values = lax.dot_general(
relative_attention_bias,
rp_bucket_one_hot,
(
((1,), (0,)), # rhs, lhs contracting dims
((), ()))) # no batched dims
# Add a singleton batch dimension.
# --> shape (1, num_heads, qlen, klen)
return values[jnp.newaxis, ...]
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
relative_embedding: nn.Module = None
num_heads: int = 8
head_dim: int = 64
dropout_rate: float = 0.1
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',)
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = 'layernorm'
output_layernorm: bool = False
drop_path: float = 0.0
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
else:
rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
# Attention block.
residual = inputs
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name="pre_attention_layer_norm")(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
else:
x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(num_heads=self.num_heads,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.dropout_rate,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
name='attention')(x,
x,
encoder_mask,
encoder_bias,
deterministic=deterministic)
x = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(x, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
x = x + residual
# MLP block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name='pre_mlp_layer_norm')(x)
if self.apply_residual_connection_post_layernorm:
residual = y
# [batch, length, emb_dim] -> [batch, length, emb_dim]
y = MlpBlock(
transpose_batch_sequence=self.transpose_batch_sequence,
intermediate_dim=self.mlp_dim,
activations=self.mlp_activations,
intermediate_dropout_rate=self.dropout_rate,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name='mlp',
)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(y, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(y, deterministic=deterministic)
y = y + residual
if self.output_layernorm:
y = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name="output_layer_norm")(y)
return y
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
relative_embedding: nn.Module = None
num_heads: int = 8
head_dim: int = 64
dropout_rate: float = 0.1
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',)
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
layernorm_type: str = 'layernorm'
drop_path: float = 0.0
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
@nn.compact
def __call__(self,
inputs,
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
deterministic=False,
decode=False,
max_decode_length=None):
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
else:
rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False)
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
residual = inputs
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name="pre_self_attention_layer_norm")(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
else:
x = inputs
# Self-attention block
x = MultiHeadAttention(num_heads=self.num_heads,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.dropout_rate,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
name='self_attention')(x,
x,
decoder_mask,
decoder_bias,
deterministic=deterministic,
decode=decode)
x = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(x, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
x = x + residual
# Encoder-Decoder block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name='pre_cross_attention_layer_norm')(x)
if self.apply_residual_connection_post_layernorm:
residual = y
y = MultiHeadAttention(num_heads=self.num_heads,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.dropout_rate,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
name='encoder_decoder_attention')(y,
encoded,
encoder_decoder_mask,
deterministic=deterministic)
y = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(y, deterministic=deterministic)
y = y + residual
# MLP block.
residual = y
z = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name='pre_mlp_layer_norm')(y)
if self.apply_residual_connection_post_layernorm:
residual = z
z = MlpBlock(
transpose_batch_sequence=self.transpose_batch_sequence,
intermediate_dim=self.mlp_dim,
activations=self.mlp_activations,
intermediate_dropout_rate=self.dropout_rate,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name='mlp',
)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(z, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
z = z + residual
if self.output_layernorm:
z = LayerNorm(layernorm_type=self.layernorm_type,
dtype=self.dtype,
name="output_layer_norm")(z)
return z
def assert_allclose(actual,
desired,
rtol=1e-05,
...
@@ -2,5 +2,8 @@
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .fp8 import fp8_autocast, update_collections, update_fp8_metas
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import RelativePositionBiases, TransformerLayer, TransformerLayerType
from .sharding import ShardingResource
@@ -276,7 +276,7 @@ class CastTransposePrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get([ir_in_shape[0], ir_in_shape[1]], ir_out_dtype),
ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -427,7 +427,7 @@ class GatedGeluFp8Primitive(BasePrimitive):
batch_size = ir_in_shape[0] # In Transformer, batch_size = batch x seqlen
out_types = [
ir.RankedTensorType.get([batch_size, hidden_size // 2], ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -599,7 +599,7 @@ class DgatedGeluCastTransposePrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get([gi_batch_size, gi_hidden_size], ir_out_dtype),
ir.RankedTensorType.get([gi_hidden_size, gi_batch_size], ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, gelu_inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, gi_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -915,7 +915,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get((batch_size,), ir_mu_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
@@ -1196,7 +1196,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, w_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
@@ -1369,7 +1369,7 @@ class QuantizePrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(ir_out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
...
@@ -21,7 +21,6 @@ jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
@@ -45,7 +44,6 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
amax,
scale,
scale_inv,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
@@ -77,7 +75,6 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
[sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
partial_fp8_dot = partial(_fp8_dot,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
@@ -93,18 +90,17 @@ def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
@@ -121,7 +117,6 @@ def _fp8_dot_fwd(
amax,
scale,
scale_inv,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
@@ -139,6 +134,8 @@ def _fp8_dot_fwd(
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
amax = FP8Helper.update_amax_history(amax)
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
@@ -170,7 +167,6 @@ def _fp8_dot_fwd(
def _fp8_dot_bwd(
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
@@ -202,9 +198,9 @@ def _fp8_dot_bwd(
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
...
@@ -4,12 +4,14 @@
"""
Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
@@ -110,6 +112,12 @@ class FP8GemmPackage:
return self._scale_inv
class AmaxComputeAlgo(Enum):
"""AmaxComputeAlgo."""
MAX = "max"
MOST_RECENT = "most_recent"
class FP8Helper:
"""
FP8 helper to manage the FP8 meta
"""
@@ -120,7 +128,8 @@ class FP8Helper:
FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_LEN: int = 1
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT
NUM_META_PER_GEMM: int = 3
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
@@ -130,12 +139,9 @@ class FP8Helper:
FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
FP8_MAX_NAME: str = "fp8_max"
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = True
FP8_2X_ACC_WGRAD: bool = True
@staticmethod
def enable_fp8():
@@ -148,7 +154,8 @@ class FP8Helper:
def initialize(margin: float = 0.0,
fp8_format: Format = Format.HYBRID,
update_fp8meta_interval: int = 1,
amax_history_len: int = 1,
amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MOST_RECENT) -> None:
"""
Initialize the FP8 meta
"""
@@ -158,13 +165,11 @@ class FP8Helper:
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
FP8Helper.AMAX_HISTORY_LEN = amax_history_len
FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo
FP8Helper.FP8_2X_ACC_FPROP = False
FP8Helper.FP8_2X_ACC_DGRAD = True
FP8Helper.FP8_2X_ACC_WGRAD = True
@staticmethod
def finalize() -> None:
@@ -177,10 +182,19 @@ class FP8Helper:
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_LEN = 1
@staticmethod
def update_amax_history(amax_buffers: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax_buffers = jnp.roll(amax_buffers, -1, 1)
updated_amax_buffers = updated_amax_buffers.at[:, 0].set(0)
return updated_amax_buffers
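# Sketch of the circular-buffer semantics (illustrative; not part of the commit):
# jnp.roll(..., -1, 1) shifts each per-meta amax row left by one position, and
# slot 0 is then cleared so the subsequent forward/backward pass can record a
# fresh amax there via the `.at[idx, 0].set(...)` calls in the dot primitives.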
@staticmethod
def update_collections(new: Collection, original: Collection) -> Collection:
""" """
Update the collections Update the collections
""" """
@@ -244,7 +258,10 @@ class FP8Helper:
fp8_scale_inv_idx = fp8_scale_idx + 1
fp8_max = fp8_meta_arrays[fp8_max_idx]
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=1, keepdims=True)
else:
amax = fp8_meta_arrays[fp8_amax_idx][:, 0:1]
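# Illustrative (hypothetical values): for an amax history row [0.5, 2.0, 1.0],
# AmaxComputeAlgo.MAX reduces it to 2.0, while MOST_RECENT keeps slot 0, i.e. 0.5;
# both reductions keep a trailing axis of size 1 so the scale-update arithmetic
# below broadcasts elementwise.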
scale = fp8_meta_arrays[fp8_scale_idx]
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
@@ -262,7 +279,7 @@ class FP8Helper:
def fp8_autocast(enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
sharding_resource: Optional[ShardingResource] = None) -> None:
r"""
Context manager for FP8 usage.
.. code-block:: python
@@ -284,15 +301,15 @@ def fp8_autocast(enabled: bool = False,
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
:attr:`amax_history_len` in recipe.DelayedScaling currently. Other parameters
in recipe.DelayedScaling would be ignored, even if set.
Parameters
----------
enabled: bool, default = False
whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
recipe used for FP8 training.
sharding_resource: ShardingResource, default = None
specify the mesh axes for data and tensor parallelism to shard along.
If set to None, then ShardingResource() would be created.
@@ -300,19 +317,70 @@ def fp8_autocast(enabled: bool = False,
if fp8_recipe is None:
fp8_recipe = DelayedScaling()
if sharding_resource is None:
sharding_resource = ShardingResource()
try:
with global_shard_guard(sharding_resource):
if enabled:
amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
if fp8_recipe.amax_compute_algo == 'max':
amax_compute_algo = AmaxComputeAlgo.MAX
FP8Helper.initialize(margin=fp8_recipe.margin,
fp8_format=fp8_recipe.fp8_format,
update_fp8meta_interval=fp8_recipe.interval,
amax_history_len=fp8_recipe.amax_history_len,
amax_compute_algo=amax_compute_algo)
yield
finally:
FP8Helper.finalize()
# Function Wrappers
def update_collections(new: Collection, original: Collection) -> Collection:
r"""
A helper to update Flax's Collection. Collection is a union type of dict and
Flax's FrozenDict.
Collection = [dict, FrozenDict]
Parameters
----------
new: Collection
A collection that includes new data.
original: Collection
The base collection.
Returns
-------
outputs : Collection
The updated collection.
"""
return FP8Helper.update_collections(new, original)
def update_fp8_metas(state: Collection) -> Collection:
r"""
Calculate new FP8 scales and their inverses via the following formulas:
`exp` = floor(log2(`fp8_max` / `amax`)) - `margin`
`sf` = round(power(2, abs(`exp`)))
`sf` = `sf` if `amax` > 0.0, else `original_scale`
`sf` = `sf` if isfinite(`amax`), else `original_scale`
`updated_scale` = 1/`sf` if `exp` < 0, else `sf`
`updated_scale_inv` = 1/`updated_scale`
Collection = [dict, FrozenDict]
Parameters
----------
state: Collection
A collection that includes FP8 metas.
Returns
-------
outputs : Collection
The collection with updated FP8 metas.
"""
return FP8Helper.update_fp8_metas(state)
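# Worked check of the formulas above (illustrative; values assume E4M3 with
# fp8_max = 448 and margin = 0):
#   amax = 3.0  -> exp = floor(log2(448 / 3.0)) = 7   -> sf = 2**7 = 128
#                  exp >= 0, so updated_scale = 128 and updated_scale_inv = 1/128
#   amax = 1000 -> exp = floor(log2(448 / 1000)) = -2 -> sf = 2**2 = 4
#                  exp < 0, so updated_scale = 1/4 and updated_scale_inv = 4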
@@ -132,7 +132,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                       gamma: jnp.ndarray,
                       beta: jnp.ndarray,
                       layernorm_type: str,
-                      amax_history_idx: int,
                       fwd_dtype: TEDType,
                       bwd_dtype: TEDType,
                       contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
@@ -167,7 +166,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
                       scale,
                       scale_inv,
                       layernorm_type,
-                      amax_history_idx,
                       fwd_dtype,
                       bwd_dtype,
                       contracting_dims,
@@ -213,7 +211,6 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
        partial_ln_fp8_dot = partial(_layernorm_fp8_dot,
                                     layernorm_type=layernorm_type,
-                                    amax_history_idx=amax_history_idx,
                                     fwd_dtype=fwd_dtype,
                                     bwd_dtype=bwd_dtype,
                                     contracting_dims=contracting_dims,
@@ -233,7 +230,7 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
    return output

-@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
+@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _layernorm_fp8_dot(inputs: jnp.ndarray,
                       kernel: jnp.ndarray,
                       gamma: jnp.ndarray,
@@ -243,7 +240,6 @@ def _layernorm_fp8_dot(inputs: jnp.ndarray,
                       scale: jnp.ndarray,
                       scale_inv: jnp.ndarray,
                       layernorm_type: str,
-                      amax_history_idx: int,
                       fwd_dtype: TEDType,
                       bwd_dtype: TEDType,
                       contracting_dims: Tuple[Sequence[int], Sequence[int]],
@@ -252,9 +248,9 @@ def _layernorm_fp8_dot(inputs: jnp.ndarray,
                       tp_axis_name: str,
                       epsilon: float = 1e-6) -> jnp.ndarray:
    output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
-                                      scale_inv, layernorm_type, amax_history_idx, fwd_dtype,
-                                      bwd_dtype, contracting_dims, sharding_type, dp_axis_name,
-                                      tp_axis_name, epsilon)
+                                      scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
+                                      contracting_dims, sharding_type, dp_axis_name, tp_axis_name,
+                                      epsilon)
    return output
@@ -268,7 +264,6 @@ def _layernorm_fp8_dot_fwd(
        scale,
        scale_inv,
        layernorm_type,
-       amax_history_idx,  # pylint: disable=unused-argument
        fwd_dtype,
        bwd_dtype,  # pylint: disable=unused-argument
        contracting_dims,
@@ -286,6 +281,8 @@ def _layernorm_fp8_dot_fwd(
    kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
    assert input_contracting_size == kernel_contracting_size

+   amax = FP8Helper.update_amax_history(amax)
+
    gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
    input_amax = amax[gemm_input_idx]
@@ -337,7 +334,6 @@ def _layernorm_fp8_dot_fwd(
def _layernorm_fp8_dot_bwd(
        layernorm_type,
-       amax_history_idx,
        fwd_dtype,
        bwd_dtype,
        contracting_dims,  # pylint: disable=unused-argument
@@ -386,9 +382,9 @@ def _layernorm_fp8_dot_bwd(
        grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
        grad_beta = None

-   amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
-   amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
-   amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
+   amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
+   amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
+   amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])

    if is_dp_enabled(sharding_type.value[0]):
        wgrad = jax.lax.psum(wgrad, dp_axis_name)
......
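# The backward passes above now always write the newest amax into slot 0
# because the forward pass first rotates the history buffer. A minimal sketch
# of such a circular-buffer update (illustrative only; the actual
# FP8Helper.update_amax_history may differ in details such as zeroing):
import jax.numpy as jnp

def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
    # Shift every row one slot to the right so the oldest entry drops out
    # and slot 0 is free for the measurement recorded later in the step.
    amax = jnp.roll(amax, shift=1, axis=-1)
    return amax.at[:, 0].set(0.0)

history = jnp.array([[3.0, 2.0, 1.0]])   # newest-first history for one tensor
print(update_amax_history(history))      # [[0. 3. 2.]]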
@@ -101,7 +101,6 @@ def fp8_ln_mlp(
        ln_scale: jnp.ndarray,
        ln_bias: jnp.ndarray,
        layernorm_type: str,
-       amax_history_idx: int,
        fwd_dtype: TEDType,
        bwd_dtype: TEDType,
        epsilon: float = 1e-6,
@@ -130,8 +129,8 @@ def fp8_ln_mlp(
    assert activations == ('gelu', 'linear')
    if major_sharding_type is MajorShardingType.SINGLE:
        res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
-                      scale_inv, layernorm_type, amax_history_idx, activations, epsilon, fwd_dtype,
-                      bwd_dtype, contracting_dims, major_sharding_type, "", "")
+                      scale_inv, layernorm_type, activations, epsilon, fwd_dtype, bwd_dtype,
+                      contracting_dims, major_sharding_type, "", "")
    else:
        dp_axis_name = "batch"
        tp_axis_name = "model"
@@ -177,7 +176,6 @@ def fp8_ln_mlp(
        partial_fp8_mlp = partial(_fp8_mlp,
                                  layernorm_type=layernorm_type,
-                                 amax_history_idx=amax_history_idx,
                                  activations=activations,
                                  epsilon=epsilon,
                                  fwd_dtype=fwd_dtype,
@@ -198,10 +196,10 @@ def fp8_ln_mlp(
    return res

-@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
+@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
             kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
-             scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, amax_history_idx: int,
+             scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
             activations: Sequence[Union[str, Callable]], epsilon: float, fwd_dtype: TEDType,
             bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
             major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
@@ -215,7 +213,6 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
        scale,
        scale_inv,
        layernorm_type,
-       amax_history_idx,
        activations,
        epsilon,
        fwd_dtype,
@@ -238,7 +235,6 @@ def _fp8_mlp_fwd(
        scale,
        scale_inv,
        layernorm_type,
-       amax_history_idx,  # pylint: disable=unused-argument
        activations,
        epsilon,
        fwd_dtype,
@@ -266,6 +262,8 @@ def _fp8_mlp_fwd(
    kernel_1_ = jnp.reshape(kernel_1, (kernel_1_pre_size, -1))
    kernel_2_ = jnp.reshape(kernel_2, (kernel_2_pre_size, -1))

+   amax = FP8Helper.update_amax_history(amax)
+
    gemm1_input_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
    input_amax = amax[gemm1_input_idx]
@@ -335,7 +333,6 @@ def _fp8_mlp_fwd(
def _fp8_mlp_bwd(
        layernorm_type,
-       amax_history_idx,
        activations,  # pylint: disable=unused-argument
        epsilon,
        fwd_dtype,
@@ -405,12 +402,12 @@ def _fp8_mlp_bwd(
    grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
    grad_beta = None

-   amax = amax.at[gemm1_input_idx, amax_history_idx].set(ln_out_amax[0])
-   amax = amax.at[gemm1_kernel_idx, amax_history_idx].set(kernel_1_amax[0])
-   amax = amax.at[gemm1_grad_idx, amax_history_idx].set(dgelu_amax[0])
-   amax = amax.at[gemm2_input_idx, amax_history_idx].set(gated_gelu_amax[0])
-   amax = amax.at[gemm2_kernel_idx, amax_history_idx].set(kernel_2_amax[0])
-   amax = amax.at[gemm2_grad_idx, amax_history_idx].set(grad_amax[0])
+   amax = amax.at[gemm1_input_idx, 0].set(ln_out_amax[0])
+   amax = amax.at[gemm1_kernel_idx, 0].set(kernel_1_amax[0])
+   amax = amax.at[gemm1_grad_idx, 0].set(dgelu_amax[0])
+   amax = amax.at[gemm2_input_idx, 0].set(gated_gelu_amax[0])
+   amax = amax.at[gemm2_kernel_idx, 0].set(kernel_2_amax[0])
+   amax = amax.at[gemm2_grad_idx, 0].set(grad_amax[0])

    if major_sharding_type in (MajorShardingType.DP, MajorShardingType.DPTP):
        wgrad_1 = jax.lax.psum(wgrad_1, dp_axis_name)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from .dot import fp8_dot
from .fp8 import FP8GemmPackage, FP8Helper
from .layernorm import canonicalize_layernorm_type
from .layernorm import layernorm, layernorm_fp8_dot
from .mlp import fp8_ln_mlp, geglu
from .sharding import infer_sharding_type
from .softmax import is_softmax_kernel_available
from .sharding import MajorShardingType, ShardingType
from .softmax import softmax, SoftmaxType
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
def _canonicalize_tuple(x):
if isinstance(x, Iterable):
return tuple(x)
return (x,)
def _create_layernorm_parameters(layernorm_type, shape, scale_init, scale_axes, bias_init,
bias_axes, dtype):
scale = nn_partitioning.param_with_axes('scale',
scale_init,
shape,
jnp.float32,
axes=scale_axes)
scale = jnp.asarray(scale, dtype)
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'layernorm':
bias = nn_partitioning.param_with_axes('ln_bias',
bias_init,
shape,
jnp.float32,
axes=bias_axes)
bias = jnp.asarray(bias, dtype)
else:
assert layernorm_type == 'rmsnorm'
bias = None
return scale, bias
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def _combine_biases(*masks: List[Array]):
"""Combine attention biases."""
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
mask, *other_masks = masks
for other_mask in other_masks:
mask = mask + other_mask
return mask
class Softmax(nn.Module):
r"""
Applies softmax over a mini-batch of inputs.
The input's shape should be [batch, heads, q_seqlen, k_seqlen].
Parameters
----------
scale_factor : float, default = 1.0
scale the inputs along the last dimension before running softmax.
softmax_type : SoftmaxType, default = SoftmaxType.SCALED
indicate the type of softmax.
Optimization parameters
-----------------------
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
"""
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
sharding_type: ShardingType = ShardingType.SINGLE
@nn.compact
def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray:
batch = inputs.shape[0]
heads = inputs.shape[1]
q_seqlen = inputs.shape[2]
k_seqlen = inputs.shape[3]
dtype = inputs.dtype
logits = inputs
if (self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype)):
if bias is not None:
logits = logits + bias.astype(dtype)
mask_ = mask
if self.softmax_type is not SoftmaxType.SCALED_MASKED:
mask_ = None
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type,
self.sharding_type)
else:
attention_bias = None
if mask is not None:
attention_bias = lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype))
if bias is not None:
attention_bias = _combine_biases(attention_bias, bias)
if attention_bias is not None:
logits = logits + attention_bias.astype(dtype)
            # If self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED but its
            # fused kernel is unavailable, fall back to the pure scaled softmax custom call.
if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen,
dtype):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED,
self.sharding_type)
else:
outputs = jax_nn.softmax(logits)
return outputs
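# A minimal usage sketch of the Softmax module above (shapes and scale factor
# are arbitrary; whether the fused kernel or the jax.nn fallback runs depends
# on is_softmax_kernel_available for the given shape and dtype):
import jax
import jax.numpy as jnp

logits = jnp.ones((2, 8, 128, 128), dtype=jnp.float16)  # [batch, heads, q_seqlen, k_seqlen]
module = Softmax(scale_factor=0.125, softmax_type=SoftmaxType.SCALED)
variables = module.init(jax.random.PRNGKey(0), logits)
probs = module.apply(variables, logits)                 # same shape as logits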
class LayerNorm(nn.Module):
r"""
Applies layer normalization over a mini-batch of inputs.
There are two types of normalization supported by this module:
regular layer normalization and root-mean-square layer normalization (RMSNorm).
The regular layer normalization is as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size of each input sample.
The root mean square layer normalization (RMSNorm) is as described in
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
    y = \frac{x}{\mathrm{RMS}[x]} * \gamma
.. math::
    \mathrm{RMS}[x] = \sqrt{\mathrm{E}[x^2] + \epsilon}
:math:`\gamma` is learnable affine transform parameters of
size of each input sample.
Parameters
----------
epsilon : float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
scale_init : Initializer, default = flax.linen.initializers.ones
used for initializing scale factors :math:`\gamma`.
scale_axes : Tuple[str, ...], default = ('embed', )
the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
bias_init : Initializer, default = flax.linen.initializers.zeros
used for initializing shift factors :math:`\beta`,
only works when :attr:`layernorm_type='layernorm'`.
bias_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only works when :attr:`layernorm_type='layernorm'`.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the batch and sequence-length axes of the input tensors
are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden); otherwise in (batch, seqlen, hidden).
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
scale_init: Initializer = nn.initializers.ones
scale_axes: Tuple[str, ...] = ('embed',)
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ('embed',)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
sharding_type: ShardingType = ShardingType.SINGLE
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""
        Applies layer normalization to the input :attr:`x`.
        Parameters
        ----------
        x : jax.numpy.ndarray
            Input tensors.
Returns
-------
outputs : jax.numpy.ndarray
Output tensors.
"""
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.bias_init, self.bias_axes, self.dtype)
return layernorm(x,
scale,
ln_bias,
self.layernorm_type,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
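# A minimal usage sketch of the LayerNorm module above. Note the default
# transpose_batch_sequence=True expects (seqlen, batch, hidden) inputs.
import jax
import jax.numpy as jnp

x = jnp.ones((128, 4, 512))                   # (seqlen, batch, hidden)
layer = LayerNorm(layernorm_type='rmsnorm')   # RMSNorm creates no ln_bias
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)                 # normalized, same shape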
class TransformerEngineBase(nn.Module):
"""
Base class of Transformer Engine modules
"""
@staticmethod
def get_fp8_metas(num_of_gemm: int) -> List[jnp.ndarray]:
"""
Get the FP8 metas
"""
num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
axes = ('fp8_meta_axis', 'fp8_meta_history')
fp8_max = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_MAX_NAME,
FP8Helper.generate_fp8_max_array,
num_of_meta,
axes=axes)
fp8_metas_amax = nn_partitioning.variable_with_axes(
FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_AMAX_NAME,
jnp.zeros, (num_of_meta, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32,
axes=axes)
fp8_metas_scale = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_SCALE_NAME,
jnp.ones, (num_of_meta, 1),
jnp.float32,
axes=axes)
fp8_metas_scale_inv = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_SCALE_INV_NAME,
jnp.ones, (num_of_meta, 1),
jnp.float32,
axes=axes)
return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value
@staticmethod
def get_fp8_gemm_package(num_of_gemm: int, inputs: jnp.ndarray,
kernels: List[jnp.ndarray]) -> FP8GemmPackage:
"""
Pack the inputs and kernels, together with newly created FP8 metas, into an FP8GemmPackage
"""
assert num_of_gemm == len(kernels)
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
TransformerEngineBase.get_fp8_metas(num_of_gemm)
return FP8GemmPackage(num_of_gemm, inputs, kernels, fp8_max, fp8_metas_amax,
fp8_metas_scale, fp8_metas_scale_inv)
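# Layout note, assuming FP8Helper.NUM_META_PER_GEMM == 3 (one input, kernel
# and grad meta per GEMM): the meta arrays above are stacked per GEMM, so for
# num_of_gemm == 2, rows [0:3] describe the first GEMM and rows [3:6] the
# second. LayerNormMLP below relies on exactly this when it slices
# fp8_max[:FP8Helper.NUM_META_PER_GEMM, :] for its first GEMM.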
class DenseGeneral(TransformerEngineBase):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
Parameters
----------
features : Union[Iterable[int], int]
the hidden size of each output sample.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weights.
kernel_axes : Tuple[str, ...], default = ()
the name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
indicate whether to enable bias shifting.
if set to False, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias, only works when :attr:`use_bias=True`.
bias_axes: Tuple[str, ...], default = ()
the name of axes used to shard bias with a corresponding mesh,
only works when :attr:`use_bias=True`.
axis: Union[Iterable[int], int], default = -1
an integer or a tuple of axes to apply the transformation on.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the batch and sequence-length axes of the input tensors
are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden); otherwise in (batch, seqlen, hidden).
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
"""
features: Union[Iterable[int], int]
kernel_init: Initializer = None
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
sharding_type: ShardingType = ShardingType.SINGLE
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
super().__post_init__()
@nn.compact
def __call__(self, inputs: Array) -> Array:
"""
Apply the linear transformation to the input.
Parameters
----------
inputs : jax.numpy.ndarray
Input tensors.
Returns
-------
outputs : jax.numpy.ndarray
Output tensors.
"""
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes('kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,),
self.dtype,
axes=self.bias_axes)
else:
bias = None
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(1, inputs, [kernel])
y = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
return y
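# A minimal usage sketch of DenseGeneral (arbitrary shapes). Outside of an
# fp8_autocast scope the module takes the plain lax.dot_general path above.
import jax
import jax.numpy as jnp

x = jnp.ones((128, 4, 512))                 # (seqlen, batch, hidden)
layer = DenseGeneral(features=1024, use_bias=True)
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)               # shape (128, 4, 1024)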
class LayerNormDenseGeneral(TransformerEngineBase):
r"""
Applies layer normalization followed by linear transformation to the incoming data.
Parameters
----------
features : Union[Iterable[int], int]
the hidden size of each output sample.
enable_layernorm: bool, default = True
indicate whether to enable layer normalization before linear transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
epsilon : float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones
used for initializing scale factors :math:`\gamma`.
scale_axes : Tuple[str, ...], default = ('embed', )
the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only works when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing shift factors :math:`\beta`,
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weights.
kernel_axes : Tuple[str, ...], default = ()
the name of axes used to shard the weights with a corresponding mesh.
use_bias: bool, default = False
indicate whether to enable bias shifting.
if set to False, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias, only works when :attr:`use_bias=True`.
bias_axes: Tuple[str, ...], default = ()
the name of axes used to shard bias with a corresponding mesh,
only works when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
indicate whether to return the output of layer normalization.
If set to False, None is returned as the second tensor in outputs.
axis: Union[Iterable[int], int], default = -1
an integer or a tuple of axes to apply the transformation on.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the batch and sequence-length axes of the input tensors
are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden); otherwise in (batch, seqlen, hidden).
depth_scaling: float, default = None
the factor used to scale the output of the dense layer. It should be a
float or None. If set to None, no scaling is applied.
sharding_type : ShardingType, default = ShardingType.SINGLE
indicate the sharding pattern.
"""
features: Union[Iterable[int], int]
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
epsilon: float = 1e-6
scale_init: Initializer = nn.initializers.ones
scale_axes: Tuple[str, ...] = ('embed',)
ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',)
kernel_init: Initializer = None
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
return_layernorm_output: bool = True
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
depth_scaling: float = None
sharding_type: ShardingType = ShardingType.SINGLE
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
super().__post_init__()
@nn.compact
def __call__(self, inputs: Array) -> Array:
"""
Apply layer normalization to the input followed by a linear transformation.
Parameters
----------
inputs: jax.numpy.ndarray
Input tensor.
Returns
-------
outputs : jax.numpy.ndarray
Output tensors.
ln_outputs: jax.numpy.ndarray
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
ln_output = None
fuse_layernorm = FP8Helper.enable_fp8(
) and not self.return_layernorm_output and self.enable_layernorm
if self.enable_layernorm:
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.ln_bias_init, self.ln_bias_axes,
self.dtype)
if not fuse_layernorm:
y = layernorm(inputs,
scale,
ln_bias,
layernorm_type=self.layernorm_type,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
else:
assert not self.return_layernorm_output
y = inputs
else:
y = inputs
if self.return_layernorm_output:
ln_output = y
# DenseGeneral
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes('kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes)
kernel = jnp.reshape(kernel, kernel_shape)
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(1, y, [kernel])
if not fuse_layernorm:
z = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
z = layernorm_fp8_dot(fp8_gemm_package,
scale,
ln_bias,
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
else:
kernel = jnp.asarray(kernel, self.dtype)
z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init, (self.features,),
self.dtype,
axes=self.bias_axes)
if bias is not None:
z += jnp.reshape(bias, (1,) * (z.ndim - 1) + (-1,))
if self.depth_scaling is not None:
z = z / self.depth_scaling
return z, ln_output # dense_output, layer_norm_output
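# A minimal usage sketch: LayerNorm fused with the projection, returning both
# the dense output and the layernorm output (None when return_layernorm_output=False).
import jax
import jax.numpy as jnp

x = jnp.ones((128, 4, 512))
layer = LayerNormDenseGeneral(features=1024, return_layernorm_output=True)
variables = layer.init(jax.random.PRNGKey(0), x)
y, ln_out = layer.apply(variables, x)       # (128, 4, 1024), (128, 4, 512)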
class LayerNormMLP(TransformerEngineBase):
r"""
Applies layer normalization on the input followed by the MLP module,
consisting of 2 successive linear transformations, separated by given activations.
Parameters
----------
intermediate_dim: int, default = 2048
intermediate size to which input samples are projected.
enable_layernorm: bool, default = True
indicate whether to enable layer normalization before linear transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
epsilon : float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones
used for initializing scale factors :math:`\gamma`.
scale_axes : Tuple[str, ...], default = ('embed', )
the name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
only works when :attr:`enable_layernorm=True`.
ln_bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing shift factors :math:`\beta`,
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
ln_bias_axes: Tuple[str, ...], default = ('embed', )
The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh.
only works when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weight of both linear transformations.
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
the name of axes used to shard the weights with a corresponding mesh for
the weight of the first linear transformations.
kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
the name of axes used to shard the weights with a corresponding mesh for
the weight of the second linear transformations.
use_bias: bool, default = False
indicate whether to enable bias shifting.
if set to False, the layer will not learn an additive bias.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias, only works when :attr:`use_bias=True`.
bias_axes_1: Tuple[str, ...], default = ('mlp',)
the name of axes used to shard bias with a corresponding mesh for
the weight of the first linear transformations.
only works when :attr:`use_bias=True`.
bias_axes_2: Tuple[str, ...], default = ('embed',)
the name of axes used to shard bias with a corresponding mesh for
the weight of the second linear transformations.
only works when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
indicate whether to return the output of layer normalization.
If set to False, None is returned as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('relu',)
the sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
intermediate_dropout_rate: float, default = 0.1
dropout probability for the dropout op after the :attr:`activations`.
axis: Union[Iterable[int], int], default = -1
an integer or a tuple of axes to apply the transformation on.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
indicate whether the batch and sequence-length axes of the input tensors
are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden); otherwise in (batch, seqlen, hidden).
major_sharding_type : MajorShardingType, default = MajorShardingType.SINGLE
indicate the sharding pattern.
"""
intermediate_dim: int = 2048
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
epsilon: float = 1e-6
scale_init: Initializer = nn.initializers.ones
scale_axes: Tuple[str, ...] = ('embed',)
ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',)
kernel_init: Initializer = None
kernel_axes_1: Tuple[str, ...] = ('embed', 'act', 'mlp')
kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = ('mlp',)
bias_axes_2: Tuple[str, ...] = ('embed',)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rate: float = 0.1
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
super().__post_init__()
@nn.compact
def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
Parameters
----------
inputs: jax.numpy.ndarray
Input tensor.
deterministic: bool, default = False
Disable dropout ops if set to True.
Returns
-------
outputs : jax.numpy.ndarray
Output tensors.
ln_outputs: jax.numpy.ndarray
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
ln_output = None
fuse_layernorm = FP8Helper.enable_fp8(
) and not self.return_layernorm_output and self.enable_layernorm
use_fused_ln_mlp = fuse_layernorm \
and (not self.use_bias) and self.activations == ('gelu', 'linear') \
and (self.intermediate_dropout_rate < 1e-3)
first_sharding_type, second_sharding_type = infer_sharding_type(self.major_sharding_type)
# LayerNorm
if self.enable_layernorm:
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.ln_bias_init, self.ln_bias_axes,
self.dtype)
if not fuse_layernorm:
y = layernorm(inputs,
scale,
ln_bias,
layernorm_type=self.layernorm_type,
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
else:
assert not self.return_layernorm_output
y = inputs
else:
y = inputs
if self.return_layernorm_output:
ln_output = y
def kernel_1_init(key, num_kernels, stack_axis, *init_args):
kernels = []
for _ in range(num_kernels):
key, init_key = jax_random.split(key)
kernels.append(self.kernel_init(init_key, *init_args))
return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)
num_of_gemm = 2
if use_fused_ln_mlp:
num_activations = len(self.activations)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, inputs.ndim)
intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
kernel_1_shape = tuple(inputs.shape[ax] for ax in axis) + intermediate_dim
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes('wi_kernel',
kernel_1_init,
num_activations,
-2,
kernel_1_each_shape,
jnp.float32,
axes=self.kernel_axes_1)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
kernel_2_param_shape = (self.intermediate_dim, np.prod(hidden_size_tuple))
kernel_2 = nn_partitioning.param_with_axes('wo_kernel',
self.kernel_init,
kernel_2_param_shape,
jnp.float32,
axes=self.kernel_axes_2)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
contract_ind = tuple(range(0, len(axis)))
fp8_gemm_package = \
TransformerEngineBase.get_fp8_gemm_package(num_of_gemm, y, [kernel_1, kernel_2])
out = fp8_ln_mlp(fp8_gemm_package,
scale,
ln_bias,
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE,
epsilon=self.epsilon,
contracting_dims=(axis, contract_ind),
major_sharding_type=self.major_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
activations=self.activations)
else: # not use_fused_ln_mlp
def fp8_meta_generator():
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = (None, None, None,
None)
if FP8Helper.enable_fp8():
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
TransformerEngineBase.get_fp8_metas(num_of_gemm)
return fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
fp8_meta_generator()
# DenseGeneral 1
activations = []
num_activations = len(self.activations)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
intermediate_dim = _canonicalize_tuple((num_activations, self.intermediate_dim))
kernel_shape = tuple(y.shape[ax] for ax in axis) + intermediate_dim
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel = nn_partitioning.param_with_axes('wi_kernel',
kernel_1_init,
num_activations,
-2,
kernel_1_each_shape,
jnp.float32,
axes=self.kernel_axes_1)
kernel = jnp.reshape(kernel, kernel_shape)
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
fp8_gemm_package = FP8GemmPackage(
1, y, [kernel], fp8_max[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_amax[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_scale[:FP8Helper.NUM_META_PER_GEMM, :],
fp8_metas_scale_inv[:FP8Helper.NUM_META_PER_GEMM, :])
if not fuse_layernorm:
x = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
x = layernorm_fp8_dot(fp8_gemm_package,
scale,
ln_bias,
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
else: # not enable fp8
kernel = jnp.asarray(kernel, self.dtype)
x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias',
self.bias_init, (self.intermediate_dim,),
self.dtype,
axes=self.bias_axes_1)
x += jnp.reshape(bias, (1,) * (x.ndim - 1) + (-1,))
if self.activations == ('gelu', 'linear'):
z = geglu(x,
contracting_dims=(-2, -1),
sharding_type=second_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = functools.reduce(operator.mul, activations)
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
z, deterministic=deterministic) # Broadcast along length.
# DenseGeneral 2
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, z.ndim)
kernel_shape = tuple(z.shape[ax] for ax in axis) + hidden_size_tuple
kernel_param_shape = (np.prod([z.shape[ax] for ax in axis]), np.prod(hidden_size_tuple))
kernel = nn_partitioning.param_with_axes('wo_kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes_2)
kernel = jnp.reshape(kernel, kernel_shape)
contract_ind = tuple(range(0, len(axis)))
if FP8Helper.enable_fp8():
fp8_gemm_package = FP8GemmPackage(
1, z, [kernel], fp8_max[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_amax[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_scale[FP8Helper.NUM_META_PER_GEMM:, :],
fp8_metas_scale_inv[FP8Helper.NUM_META_PER_GEMM:, :])
out = fp8_dot(fp8_gemm_package,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
sharding_type=second_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
kernel = jnp.asarray(kernel, self.dtype)
out = lax.dot_general(z, kernel, ((axis, contract_ind), ((), ())))
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias',
self.bias_init, (hidden_size,),
self.dtype,
axes=self.bias_axes_2)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))
return out, ln_output    # Output, layer_norm_output
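# A minimal usage sketch of the MLP block above. deterministic=True disables
# the intermediate dropout, so no dropout RNG needs to be supplied.
import jax
import jax.numpy as jnp

x = jnp.ones((128, 4, 512))
mlp = LayerNormMLP(intermediate_dim=2048, activations=('gelu', 'linear'))
variables = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
y, ln_out = mlp.apply(variables, x, deterministic=True)   # y: (128, 4, 512)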
@@ -694,9 +694,10 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
        in_axes = [{dp_dim: dp_axis_name}]
        input_new_shapes = [input_new_shape]

-       return ShardingMeta(tuple(in_axes), ({
-           dp_dim: dp_axis_name
-       }), {dp_axis_name: dp_mesh_axis}, input_new_shapes, [input_shape])
+       out_axes = in_axes[0]
+
+       return ShardingMeta(tuple(in_axes), out_axes, {dp_axis_name: dp_mesh_axis},
+                           input_new_shapes, [input_shape])

    def get_tp_col_sharding_meta(self,
                                 input_shape: Tuple,
@@ -764,9 +765,10 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
        in_axes = [{tp_dim: tp_axis_name}]
        input_new_shapes = [input_new_shape]

-       return ShardingMeta(tuple(in_axes), ({
-           tp_dim: tp_axis_name
-       }), {tp_axis_name: tp_mesh_axis}, input_new_shapes, [input_shape])
+       out_axes = in_axes[0]
+
+       return ShardingMeta(tuple(in_axes), out_axes, {tp_axis_name: tp_mesh_axis},
+                           input_new_shapes, [input_shape])

    @staticmethod
    def _get_dptp_sharding_meta(input_shape: Tuple,
@@ -794,7 +796,7 @@ class SoftmaxShardingMetaGenerator(ShardingMetaGenerator):
        in_axes = [{dp_dim: dp_axis_name, tp_dim + 1: tp_axis_name}]
        input_new_shapes = [input_new_shape]

-       out_axes = in_axes
+       out_axes = in_axes[0]

        return ShardingMeta(tuple(in_axes), out_axes, {
            dp_axis_name: dp_mesh_axis,
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX softmax modules"""
from enum import Enum
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
from .cpp_extensions import scaled_softmax_fwd
from .cpp_extensions import scaled_softmax_bwd
from .cpp_extensions import scaled_masked_softmax_fwd
from .cpp_extensions import scaled_masked_softmax_bwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_fwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
from .sharding import get_softmax_sharding_meta, ShardingType
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
class SoftmaxType(Enum):
"""SoftmaxType."""
SCALED = "scaled"
SCALED_MASKED = "scaled_masked"
SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"
def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: int, q_seqlen: int,
k_seqlen: int, dtype: jnp.dtype):
"""check softmax available"""
if softmax_type is SoftmaxType.SCALED:
return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
dtype)
if softmax_type is SoftmaxType.SCALED_MASKED:
return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
dtype)
if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype)
raise NotImplementedError
def softmax(inputs: jnp.ndarray,
mask: Optional[jnp.ndarray] = None,
scale_factor: Optional[float] = 1.0,
softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0,
tp_dim_index: int = 1):
"""
Softmax wrapper
"""
    assert dp_dim_index == 0, \
        "softmax currently only supports the batch dim as the first axis."
    assert tp_dim_index == 1, \
        "softmax currently only supports the head dim as the second axis."
assert mask is None or mask.shape[tp_dim_index] == 1
if sharding_type is ShardingType.SINGLE:
outputs = _softmax(inputs, mask, scale_factor, softmax_type)
else:
dp_axis_name = "batch"
tp_axis_name = "model"
sharding_meta = get_softmax_sharding_meta(sharding_type,
inputs.shape,
dp_dim=dp_dim_index,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
mask_ = mask
mask_in_axis = {}
if mask_ is not None:
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
# If mask is head broadcastable (heads == 1),
# then it equals to DP sharding.
mask_sharding_meta = get_softmax_sharding_meta(ShardingType.DP,
mask_.shape,
dp_dim=dp_dim_index,
tp_dim=tp_dim_index,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
mask_ = jnp.reshape(mask_, mask_sharding_meta.input_shapes[0])
mask_in_axis = mask_sharding_meta.in_axes[0]
partial_softmax = partial(_softmax, scale_factor=scale_factor, softmax_type=softmax_type)
in_axes = (sharding_meta.in_axes[0], mask_in_axis)
outputs = xmap_runner(partial_softmax, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, mask_))
outputs = jnp.reshape(outputs, sharding_meta.output_shapes[0])
return outputs
@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(inputs, mask, scale_factor, softmax_type):
output, _ = _softmax_fwd(inputs, mask, scale_factor, softmax_type)
return output
def _softmax_fwd(inputs, mask, scale_factor, softmax_type):
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
outputs = scaled_masked_softmax_fwd(inputs, mask, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = scaled_upper_triang_masked_softmax_fwd(inputs, scale_factor)
else:
outputs = scaled_softmax_fwd(inputs, scale_factor)
return outputs, (outputs, mask)
def _softmax_bwd(scale_factor, softmax_type, ctx, grad_outputs):
softmax_outputs, mask = ctx
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
dgrad = scaled_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
dgrad = scaled_upper_triang_masked_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
else:
dgrad = scaled_softmax_bwd(grad_outputs, softmax_outputs, scale_factor)
return (dgrad, None)
_softmax.defvjp(_softmax_fwd, _softmax_bwd)
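# A minimal usage sketch of the softmax wrapper above. Illustrative only: the
# fused kernels engage when is_softmax_kernel_available(...) holds for the
# shape/dtype, and the mask dtype/layout here is an assumption about the
# custom calls' expectations.
import jax.numpy as jnp

logits = jnp.ones((2, 8, 128, 128), dtype=jnp.float16)  # [batch, heads, q_seqlen, k_seqlen]
mask = jnp.zeros((2, 1, 128, 128), dtype=jnp.uint8)     # head-broadcastable padding mask
probs = softmax(logits, mask, scale_factor=0.125,
                softmax_type=SoftmaxType.SCALED_MASKED)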
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Wrapper module for Transformer related layers with FP8 support.
"""
import functools
from enum import Enum
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from .softmax import SoftmaxType
from .sharding import infer_major_sharding_type, infer_sharding_type
from .sharding import global_shard_resource, ShardingType
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
LogicalRules = Sequence[Tuple[str, Union[str, None]]]
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
drop_path_shape = list(range(0, len(shape)))
drop_path_shape.pop(batch_dim)
return drop_path_shape
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
"""
Extend the given Flax logical axis rules with the pre-defined TransformerLayer's
logical axis rules.
.. note::
We currently only support single, data parallelism and standard tensor parallelism
logical axis rules for performance reasons.
.. warning::
Please make sure ShardingResource is set via fp8_autocast before calling this function.
Parameters
----------
rules : Sequence[Tuple[str, Union[str, None]]]
the base Flax logical axis rules to extend.
Returns
-------
extended_rules : Sequence[Tuple[str, Union[str, None]]]
the extended Flax logical axis rules.
"""
rules_map = {}
for item in rules:
assert len(item) == 2, \
"The logical axis rule should be like (axis_name, mesh_axis_name)."
key = item[0]
val = item[1]
assert isinstance(key, str), \
f"Thie axis_name should be str, but got {type(key)}."
assert isinstance(val, str) or (val is None), \
f"Thie mesh_axis_name should be str or None, but got {type(val)}."
rules_map[key] = val
gsr = global_shard_resource()
te_logical_axis_rules = (('batch', gsr.dp_resource), ('embed', None), ('mlp', gsr.tp_resource),
('heads', gsr.tp_resource), ('kv', None), ('qkv_dim', None),
('kv_dim', None), ('joined_kv', gsr.tp_resource), ('act', None),
('relpos_buckets', None), ('length', None))
extended_rules = [*rules]
for item in te_logical_axis_rules:
key = item[0]
val = item[1]
if key in rules_map:
            assert rules_map[key] == val, \
                f"The rule diverged between TE and the given rules. " \
                f"Axis {key} maps to {rules_map[key]} in the given rules, " \
                f"but to {val} in TE's rules."
else:
extended_rules.append(item)
return tuple(extended_rules)
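# A sketch of extending user rules with TE's. Per the warning above, a
# ShardingResource must already be installed (e.g. via fp8_autocast); the
# 'data'/'model' mesh-axis names and import paths are illustrative.
from transformer_engine.jax.fp8 import fp8_autocast
from transformer_engine.jax.sharding import ShardingResource

with fp8_autocast(enabled=False,
                  sharding_resource=ShardingResource('data', 'model')):
    rules = extend_logical_axis_rules((('my_axis', None),))
    # rules now also carries TE's entries, e.g. ('batch', 'data'),
    # ('mlp', 'model'), ('heads', 'model'), ('embed', None), ...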
def _merge_mask(func, *masks: Optional[Array]):
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
mask, *other_masks = masks
for other_mask in other_masks:
mask = func(mask, other_mask)
return mask
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
"""Combine attention masks."""
func = jnp.logical_and
return _merge_mask(func, *masks).astype(dtype)
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases."""
func = lambda a, b: a + b
return _merge_mask(func, *masks)
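# A tiny worked example of the two helpers above: masks combine with a
# logical AND (then cast to the requested dtype), biases combine by addition.
import jax.numpy as jnp

m1 = jnp.array([[1, 1, 0]])
m2 = jnp.array([[1, 0, 1]])
print(combine_masks(m1, m2))    # [[1. 0. 0.]]
print(combine_biases(m1, m2))   # [[2 1 1]]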
def core_attention(query: Array,
key: Array,
value: Array,
transpose_batch_sequence: bool,
softmax_type: SoftmaxType = SoftmaxType.SCALED,
softmax_sharding_type: ShardingType = ShardingType.SINGLE,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: DType = jnp.float32,
float32_logits: bool = False):
"""Core attention"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert query.shape[-2] == key.shape[-2] == value.shape[-2], ('q, k, v num_heads must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
attn_weights = Softmax(softmax_type=softmax_type,
sharding_type=softmax_sharding_type)(attn_weights, mask, bias)
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
dropout_shape = list(attn_weights.shape)
dropout_shape[-2] = 1
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
keep = jnp.broadcast_to(keep, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
if transpose_batch_sequence:
return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
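# To make the two batch/sequence layouts above concrete, here is the logits
# contraction alone (arbitrary shapes; Softmax and dropout are omitted):
import jax.numpy as jnp

q = jnp.ones((128, 4, 8, 64))                  # (seqlen, batch, heads, head_dim)
k = jnp.ones((128, 4, 8, 64))
logits = jnp.einsum('qbhd,kbhd->bhqk', q, k)   # -> (batch, heads, q_seqlen, k_seqlen)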
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class AttentionType(Enum):
"""TransformerLayerType."""
PADDING = "padding_attention"
CAUSAL = "causal_attention"
class MultiHeadAttention(nn.Module):
r"""
Multi-head Attention (MHA), including Query,
Key, Value and Output projection.
Parameters
----------
head_dim : int
the hidden dimension of each attention heads.
num_heads : int
the number of attention heads
dropout_rate : float, default = 0.0
dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout'
the key in the RNG dict given to flax.linen.Module.apply that is used
to generate dropout masks in the core attention.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
used for initializing weights of QKV and Output projection weights.
use_bias: bool, default = False
indicate whether to enable bias shifting for QKVO projections.
if set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias of QKVO projections, only works when :attr:`use_bias=True`.
apply_residual_connection_post_layernorm : bool, default = False
indicate if apply residual connection with the output of layer normalization.
output_layernorm : bool, default = False
indicate if apply a layer normalization in the end of MHA.
attn_type: AttentionType, default = AttentionType.PADDING
    indicate the format of the attention mask in the core attention.
Optimization parameters
-----------------------
dtype :jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
fuse_qkv: bool, default = True
if set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = True
indicate whether the batch and sequence-length axes of the input tensors
are swapped. If set to True, the input tensors should be in
(seqlen, batch, hidden); otherwise in (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
indicate whether to scale attention logits.
if set to True, :math:`\frac{Q}{\sqrt{head\_dim}} * K`,
else :math:`Q * K`.
scaled_query_init: bool, default = `True`
whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
float32_logits : bool, default = False
whether to compute attention logits in float32.
"""
head_dim: int
num_heads: int
dropout_rate: float = 0.
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
kernel_init: Initializer = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_type: AttentionType = AttentionType.PADDING
dtype: DType = jnp.float32
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False # computes logits in float32 for stability.
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
super().__post_init__()
@nn.compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
decode: bool = False,
deterministic: bool = False) -> Array:
"""Applies multi-head dot product attention on the input data."""
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
def query_init(*args):
return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)
def qkv_init(key, shape, dtype):
assert len(shape) == 3
assert shape[-2] == 3
q_key, k_key, v_key = jax_random.split(key, num=3)
q_shape = (shape[0], shape[-1])
k_shape = (shape[0], shape[-1])
v_shape = (shape[0], shape[-1])
q_kernel = query_init(q_key, q_shape, dtype)
k_kernel = self.kernel_init(k_key, k_shape, dtype)
v_kernel = self.kernel_init(v_key, v_shape, dtype)
return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype)
def kv_init(key, shape, dtype):
assert len(shape) == 3
assert shape[-2] == 2
k_key, v_key = jax_random.split(key)
k_shape = (shape[0], shape[-1])
v_shape = (shape[0], shape[-1])
k_kernel = self.kernel_init(k_key, k_shape, dtype)
v_kernel = self.kernel_init(v_key, v_shape, dtype)
return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype)
first_sharding_type, second_sharding_type = infer_sharding_type()
residual = inputs_q
if self.fuse_qkv:
if inputs_q is inputs_kv:
qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
scale_axes=('embed',),
kernel_axes=('embed', 'qkv_dim', 'joined_kv'),
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
name='qkv',
dtype=self.dtype)(inputs_q)
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
query = jnp.reshape(query, (*query.shape[:-2], -1))
key = jnp.reshape(key, (*key.shape[:-2], -1))
value = jnp.reshape(value, (*value.shape[:-2], -1))
if self.scale_attn_logits:
query = query / depth_scaling
else:
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
depth_scaling=depth_scaling if self.scale_attn_logits else None,
scale_axes=('embed',),
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_init=self.bias_init,
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=(2, self.num_heads * self.head_dim),
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=('embed', 'kv_dim', 'joined_kv'),
kernel_init=kv_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
name='kv',
dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [1], axis=-2)
key = jnp.reshape(key, (*key.shape[:-2], -1))
value = jnp.reshape(value, (*value.shape[:-2], -1))
else:
kv_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_init=self.bias_init,
dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
sharding_type=first_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
depth_scaling=depth_scaling if self.scale_attn_logits else None,
scale_axes=('embed',),
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_init=self.bias_init,
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
if inputs_q is inputs_kv:
assert ln_out is not None
inputs_kv = ln_out
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
qkv_sharding_constraint = \
('length', 'batch', 'heads', 'kv') \
if self.transpose_batch_sequence \
else ('batch', 'length', 'heads', 'kv')
query = nn_partitioning.with_sharding_constraint(query, qkv_sharding_constraint)
key = nn_partitioning.with_sharding_constraint(key, qkv_sharding_constraint)
value = nn_partitioning.with_sharding_constraint(value, qkv_sharding_constraint)
if decode:
is_initialized = self.has_variable('cache', 'cached_key')
# TODO (Ming Huang): Check performance on GPU without swapping dimensions # pylint: disable=fixme
def swap_dims(x):
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
value.dtype)
cache_index = self.variable('cache', 'cache_index',
lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape
# Sanity shape check of cached key against input query.
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
f"expected query shape {expected_shape} instead got {query.shape}.")
cur_index = cache_index.value
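# The cache stores key/value with the length axis moved to the end, so the
# current token can be written into position cur_index with a one-hot
# multiply-add that broadcasts over (batch, heads, head_dim).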
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
one_token_key = jnp.moveaxis(key, -3, -1)
one_token_value = jnp.moveaxis(value, -3, -1)
key = cached_key.value + one_token_key * one_hot_indices
value = cached_value.value + one_token_value * one_hot_indices
cached_key.value = key
cached_value.value = value
cache_index.value = cache_index.value + 1
key = jnp.moveaxis(key, -1, -3)
value = jnp.moveaxis(value, -1, -3)
mask = combine_masks(
mask, jnp.broadcast_to(jnp.arange(length) <= cur_index, (batch, 1, 1, length)))
if bias is not None:
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
dropout_rng = None
if not deterministic and self.dropout_rate > 0.:
dropout_rng = self.make_rng(self.dropout_rng_name)
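# Map the attention type onto the fused softmax kernel: padding attention
# uses the scaled (optionally masked) softmax, while causal attention always
# uses the scaled upper-triangular masked softmax.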
softmax_type = SoftmaxType.SCALED
if self.attn_type is AttentionType.PADDING:
if mask is not None:
softmax_type = SoftmaxType.SCALED_MASKED
else:
softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED
x = core_attention(query,
key,
value,
transpose_batch_sequence=self.transpose_batch_sequence,
softmax_type=softmax_type,
softmax_sharding_type=first_sharding_type,
mask=mask,
bias=bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype,
float32_logits=self.float32_logits)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
attn_context_sharding_constraint = \
('length', 'batch', 'joined_kv') \
if self.transpose_batch_sequence \
else ('batch', 'length', 'joined_kv')
x = nn_partitioning.with_sharding_constraint(x, attn_context_sharding_constraint)
out = DenseGeneral(features=inputs_q.shape[-1],
sharding_type=second_sharding_type,
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'),
use_bias=self.use_bias,
bias_init=self.bias_init,
dtype=self.dtype,
name='out')(x)
return out, residual
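# Usage sketch (illustrative only, with assumed shapes; kept as comments so
# the module has no import-time side effects). With the default
# transpose_batch_sequence=True, inputs are (seqlen, batch, hidden):
#
#     mha = MultiHeadAttention(head_dim=64, num_heads=8)
#     x = jnp.zeros((128, 4, 512), jnp.float32)
#     variables = mha.init(jax_random.PRNGKey(0), x, x)
#     out, residual = mha.apply(variables, x, x)    # out: (128, 4, 512)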
class RelativePositionBiases(nn.Module):
"""
Adds T5-style relative positional embeddings to the attention logits.
Parameters
----------
num_buckets : int
the number of buckets to bucket distances between key and query positions into.
max_distance : int
the maximum distance before everything is lumped into the last
distance bucket.
num_attention_heads : int
number of attention heads in the transformer layer.
embedding_init : Initializer, default = flax.linen.linear.default_embed_init
used for initializing relative embedding tables.
embedding_axes : Tuple[str, ...], default = ('heads', 'relpos_buckets')
the name of axes used to shard embedding attention bias with a corresponding mesh.
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
"""
num_buckets: int
max_distance: int
num_attention_heads: int
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
embedding_axes: Tuple[str, ...] = ('heads', 'relpos_buckets')
dtype: DType = jnp.float32
@nn.compact
def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
"""
Generate relative position embedding attention biases.
Parameters
----------
q_seqlen : int
the sequence length of query.
k_seqlen : int
the sequence length of key.
bidirectional : bool, default = True
indicate whether to allow positive memory-query relative position
embeddings.
Returns
-------
output: jax.numpy.ndarray
An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`.
"""
context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None]
memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :]
relative_position = memory_position - context_position
# Compute relative position bucket
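# Bucketing scheme (T5): in the bidirectional case, half of the buckets are
# reserved for positive (future) offsets; of the remaining buckets, the first
# half maps small distances exactly and the second half covers larger
# distances with logarithmically sized bins up to max_distance.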
rp_bucket = 0
negative_rp = -relative_position
rpb_num_buckets = self.num_buckets
if bidirectional:
rpb_num_buckets //= 2
rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets
negative_rp = np.abs(negative_rp)
else:
negative_rp = np.maximum(negative_rp, 0)
rpb_max_exact = rpb_num_buckets // 2
rpb_is_small = negative_rp < rpb_max_exact
rpb_val_if_large = rpb_max_exact + (
np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) /
np.log(self.max_distance / rpb_max_exact) *
(rpb_num_buckets - rpb_max_exact)).astype(np.int32)
rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1)
rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)
# Compute relative attention bias
relative_attention_bias = nn_partitioning.param_with_axes(
'rel_embedding',
self.embedding_init, (self.num_attention_heads, self.num_buckets),
jnp.float32,
axes=self.embedding_axes)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
values = lax.dot_general(relative_attention_bias, rp_bucket_one_hot,
(((1,), (0,)), ((), ())))
return values[jnp.newaxis, ...]
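# Usage sketch (illustrative only, with assumed sizes; kept as comments to
# avoid import-time side effects). The returned bias can be fed to
# MultiHeadAttention via its `bias` argument:
#
#     relpos = RelativePositionBiases(num_buckets=32, max_distance=128,
#                                     num_attention_heads=8)
#     params = relpos.init(jax_random.PRNGKey(0), 16, 16, True)
#     bias = relpos.apply(params, 16, 16, True)    # (1, 8, 16, 16)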
class TransformerLayerType(Enum):
"""TransformerLayerType."""
ENCODER = "encoder"
DECODER = "decoder"
class TransformerLayer(nn.Module):
r"""
TransformerLayer is made up of a relative embedding,
an attention block and a feedforward network (MLP).
This standard layer is based on the paper “Attention Is All You Need”.
Parameters
----------
hidden_size: int, default = 512
the hidden size of each input sample.
mlp_hidden_size: int, default = 2048
intermediate size to which input samples are projected.
num_attention_heads: int, default = 8
number of attention heads in the transformer layer.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
a value added to the denominator of layer normalization for numerical stability.
hidden_dropout: float, default = 0.1
dropout probability for the dropout op after FC2 layer.
hidden_dropout_dims: Sequence[int], default = ()
dimensions that will share the same dropout mask for the hidden dropout.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
dropout_rng_name: str, default = 'dropout'
the key in given RNGs via flax.linen.Module.apply that is used
to generate Dropout masks in the Multi-Head Attention.
mha_kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
used for initializing the QKV and output projection weights.
mlp_kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
used for initializing weights of FC1 and FC2 layers.
mlp_activations: Sequence[str], default = ('relu', )
the sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
use_bias: bool, default = False
indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
if set to False, the layer will not learn additive biases.
bias_init: Initializer, default = flax.linen.initializers.zeros
used for initializing bias of QKVO projections,
FC1 and FC2, only works when :attr:`use_bias=True`.
apply_residual_connection_post_layernorm: bool, default = False
if set to True, residual connections are taken from the output
of layer norm (default is taken from input of layer norm)
output_layernorm: bool, default = False
if set to True, layer normalization is applied on the output side,
after the final dropout-add. Default behavior is to apply layer
normalization on the input side, before the QKV transformation.
float32_attention_logits: bool, default = False
if set to True, attention logits are executed in jax.numpy.float32.
layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
if set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention. This can be used for structures like the `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
enable_relative_embedding: bool, default = True
whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
the module for relative embedding execution, only works when
:attr:`enable_relative_embedding=True`. Default is None, which will create
an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`.
Default: RelativePositionBiases( num_buckets=32, max_distance=128,
num_attention_heads=self.num_attention_heads, dtype=self.dtype,
embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
drop_path: float, default = 0.0
when > 0.0, applies stochastic depth per sample in the main
path of the residual block.
fuse_qkv_params: bool, default = True
if set to True, `TransformerLayer` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = True
indicate whether the batch and sequence-length axes of the input
tensors are swapped. If set to True, the input tensors should be
in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
scale_attn_logits: bool, default = False
indicate whether to scale attention logits.
if set to True, :math:`\frac{Q}{\sqrt{head\_dim}}*K`,
else :math:`Q*K`
scaled_query_init: bool, default = `True`
whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`
"""
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
dropout_rng_name: str = 'dropout'
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ('relu',)
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
scale_attn_logits: bool = False
scaled_query_init: bool = True
def __post_init__(self):
if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal')
super().__post_init__()
@nn.compact
def __call__(self,
inputs: Array,
encoded: Array = None,
attention_mask: Array = None,
encoder_decoder_mask: Array = None,
deterministic: bool = False,
decode: bool = False,
max_decode_length: int = None):
"""
Transformer Layer: attention block and a feedforward network (MLP)
Parameters
----------
inputs : jax.numpy.ndarray
Input tensor.
encoded : jax.numpy.ndarray, default = None
Output tensors of the encoder block to be fed into the decoder block if using
:attr:`layer_type=TransformerLayerType.DECODER`.
attention_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out self-attention softmax input.
encoder_decoder_mask : jax.numpy.ndarray, default = None
Boolean tensor used to mask out cross-attention softmax input when
:attr:`layer_type=TransformerLayerType.DECODER`.
deterministic: bool, default = False
Disables dropout layers if set to True.
decode: bool, default = False
Indicate whether to prepare and use an autoregressive cache
in Multi-head attention (MHA).
max_decode_length : int, default = None
The maximum length to generate relative embedding biases when
:attr:`layer_type=TransformerLayerType.DECODER` and
:attr:`enable_relative_embedding=True`.
Returns
-------
outputs : jax.numpy.ndarray
Output tensors of this transformer block.
"""
assert self.layer_type in TransformerLayerType, \
"layer_type should be one of TransformerLayerType" \
f", but got {self.layer_type}."
assert self.hidden_size % self.num_attention_heads == 0, \
"hidden_size should be multiples of num_attention_heads" \
f", but got {self.hidden_size=} and {self.num_attention_heads=}."
assert self.layer_type == TransformerLayerType.DECODER or \
(self.layer_type == TransformerLayerType.ENCODER and decode is False), \
"decode should be False when layer_type == TransformerLayerType.ENCODER."
head_dim = self.hidden_size // self.num_attention_heads
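# With transpose_batch_sequence=True the inputs are (seqlen, batch, hidden),
# so the sequence axis is 0 and the batch axis is 1; otherwise the reverse.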
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
attn_bias = None
if self.enable_relative_embedding:
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
else:
rel_emb = self.relative_embedding
if self.layer_type == TransformerLayerType.ENCODER:
attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
else:
if decode and max_decode_length:
l = max_decode_length
else:
l = inputs.shape[sequence_dim]
attn_bias = rel_emb(l, l, False)
assert inputs.ndim == 3
self_attn_type = None
# Make the names exactly the same as T5X, since names affect the RNG keys
# during init and apply. Maybe this will not be needed in the future.
if self.layer_type == TransformerLayerType.ENCODER:
mha_name = 'attention'
self_attn_type = AttentionType.PADDING
else:
mha_name = 'self_attention'
self_attn_type = AttentionType.CAUSAL
assert self_attn_type is not None
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention(
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_type=self_attn_type,
fuse_qkv=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
name=mha_name)(inputs,
inputs,
attention_mask,
attn_bias,
deterministic=deterministic,
decode=decode)
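# Dropout helper for the residual branches; hidden_dropout_dims lists the
# axes that share one dropout mask (forwarded to Flax's broadcast_dims).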
def hidden_dropout(x, deterministic):
assert isinstance(self.hidden_dropout_dims, Sequence)
x_shape_len = len(x.shape)
for dims in self.hidden_dropout_dims:
assert -x_shape_len < dims < x_shape_len
return nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic)
x = hidden_dropout(x, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
x = x + residual
mlp_input = x
if self.layer_type == TransformerLayerType.DECODER:
assert encoded is not None, \
"encoded is required when layer_type == TransformerLayerType.DECODER."
y, residual = MultiHeadAttention(
num_heads=self.num_attention_heads,
dtype=self.dtype,
head_dim=head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
apply_residual_connection_post_layernorm=self.
apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA.
attn_type=AttentionType.PADDING,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
name='encoder_decoder_attention')(x,
encoded,
encoder_decoder_mask,
deterministic=deterministic)
y = hidden_dropout(y, deterministic)
mlp_input = y + residual
# MlpBlock
residual = mlp_input
z, ln_out = LayerNormMLP(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
major_sharding_type=infer_major_sharding_type(),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
intermediate_dropout_rate=self.hidden_dropout,
dtype=self.dtype,
scale_axes=('embed',),
kernel_init=self.mlp_kernel_init,
kernel_axes_1=('embed', 'act', 'mlp'),
kernel_axes_2=('mlp', 'embed'),
use_bias=self.use_bias,
bias_init=self.bias_init,
name='mlp',
)(mlp_input, deterministic=deterministic)
if self.apply_residual_connection_post_layernorm:
assert ln_out is not None
residual = ln_out
z = hidden_dropout(z, deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
z = z + residual
if self.output_layernorm:
ln_sharding_type, _ = infer_sharding_type()
z = LayerNorm(layernorm_type=self.layernorm_type,
scale_axes=('embed',),
bias_axes=('embed',),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
epsilon=self.layernorm_epsilon,
sharding_type=ln_sharding_type,
name="output_layer_norm")(z)
return z
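# Usage sketch (illustrative only, with assumed shapes and default
# hyper-parameters; kept as comments to avoid import-time side effects).
# An encoder layer over (seqlen, batch, hidden) inputs:
#
#     layer = TransformerLayer(hidden_size=512, mlp_hidden_size=2048,
#                              num_attention_heads=8,
#                              layer_type=TransformerLayerType.ENCODER)
#     x = jnp.zeros((128, 4, 512), jnp.float32)
#     variables = layer.init(jax_random.PRNGKey(0), x, deterministic=True)
#     y = layer.apply(variables, x, deterministic=True)    # (128, 4, 512)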