Unverified Commit dff11340 authored by Ming-Xu Huang and committed by GitHub

[JAX] Rewrite the Format of FP8 Meta and Remove unused ShardingTypes. (#842)



* Reformat FP8 Meta

1. Reformat FP8 meta to be one-set-per-tensor.
2. Remove fp8_max and scale_inv.
3. Remove unused functions in fp8.py
Signed-off-by: Ming Huang <mingh@nvidia.com>
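
For reference, a minimal sketch of the new one-set-per-tensor layout, based on the shapes used in the updated tests in this diff (the small helper functions are illustrative only):

```python
import jax.numpy as jnp
from transformer_engine.jax.fp8 import FP8Helper, FP8MetaPackage

# Each GEMM now carries one (amax, scale) pair per tensor: input, weight, grad.
# fp8_max and scale_inv are no longer stored in the package.
def new_amax():
    return jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)

def new_scale():
    return jnp.ones((1,), jnp.float32)

pkg = FP8MetaPackage(new_amax(), new_scale(),    # input
                     new_amax(), new_scale(),    # weight
                     new_amax(), new_scale())    # grad
```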

* Fix unit-tests
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove ShardingType and MajorShardingType
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix lint errors
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed unittests.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Rename a few variables.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add jit to update_amax_list
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed naming error in LayernormMLP
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed bugs in test_distributed_layernorm_mlp.py
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent dec3ef1d
......@@ -27,7 +27,6 @@ Modules
.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.update_collections
.. autoapifunction:: transformer_engine.jax.update_fp8_metas
.. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
......
......@@ -10,7 +10,7 @@ This example uses Transformer Encoder to demonstrate the Transformer Engine usag
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use the `fp8_autocast` context manager to enable FP8 training, and check whether `var_collect` contains `Float8` variables.
4. Training process: In `train_step`, combine the FP8 metadata and the latest model parameters into `var_collect` as a frozen dictionary and pass it to the gradient function. Then call `te.update_fp8_metas` to update the FP8 metadata. The number of training steps between FP8 metadata updates can be customized; in this example, it is updated every step.
4. Training process: In `train_step`, combine the FP8 metadata and the latest model parameters into `var_collect` as a frozen dictionary and pass it to the gradient function (a sketch of this flow follows the list below).
5. Evaluating process: Same as the training process, the FP8 metadata needs to be in `var_collect` and passed into the loss function if FP8 computing is enabled.
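
A minimal sketch of this train/eval flow, assuming a Flax `TrainState`-like `state` and a placeholder loss (the names and shapes here are illustrative, not the example's exact code):

```python
import jax
import jax.numpy as jnp
from flax.core import frozen_dict

def train_step(state, fp8_metas, inputs, labels):
    # Merge the FP8 metadata with the latest parameters so the gradient
    # function can cast BF16 tensors into FP8 correctly at runtime.
    var_collect = frozen_dict.freeze({**fp8_metas, "params": state.params})

    def loss_fn(var_collect):
        logits = state.apply_fn(var_collect, inputs)
        return jnp.mean((logits - labels) ** 2)    # placeholder loss

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
    return loss, grads
```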
......@@ -27,7 +27,7 @@ python test_single_gpu_encoder.py --use-fp8
2. In order to let JAX know how to shard, the `device_mesh` needs to be defined and each axis needs to be named. A common axis name is `data`, which denotes the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations. The first argument of `te.ShardingResource` is the name of the device axis used for data parallelism.
3. On the model side, the logical axis of each weight tensor of the model can be named. The `te.TransformerLayer` has default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`; the key is `params_axes`. The `te.DenseGeneral` doesn't have default named axes because it is generic. Also, data-parallel sharding doesn't need to partition the weight tensors, so a named axis is not required in this case. But because `te.DenseGeneral` is based on [XLA custom-call](https://www.tensorflow.org/xla/custom_call) and [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html), the `sharding_type` must be set so that weights are mapped to xmap correctly.
3. On the model side, the logical axis of each weight tensor of the model can be named. The `te.TransformerLayer` has default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`; the key is `params_axes`. The `te.DenseGeneral` doesn't have default named axes because it is generic. Also, data-parallel sharding doesn't need to partition the weight tensors, so a named axis is not required in this case.
4. The next step is to create sharding rules, mapping the device axes to the logical axes. Calling `te.extend_logical_axis_rules` under `fp8_autocast` returns a list of mapping pairs, such as `(('batch', 'data'), ...)`, where the first element is the logical axis and the second is the device axis. A sketch of this setup follows the list below.
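
A sketch of steps 2-4, using a 2x1 mesh with illustrative axis names. Note that the tests in this commit use `te.MeshResource` and the `mesh_resource` argument of `fp8_autocast`, which supersede `te.ShardingResource`; the resource field names below are assumptions:

```python
import numpy as np
import jax
from jax.sharding import Mesh
import transformer_engine.jax as te

# Name each mesh axis; 'data' is used for data-parallel sharding of the batch dimension.
devices = np.asarray(jax.devices()[:2]).reshape(2, 1)
mesh = Mesh(devices, ('data', 'model'))

with mesh, te.fp8_autocast(enabled=True,
                           mesh_resource=te.MeshResource(dp_resource='data',
                                                         tp_resource='model')):
    # Map logical axes to device axes; returns pairs such as (('batch', 'data'), ...).
    rules = te.extend_logical_axis_rules(tuple())
```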
......
......@@ -8,13 +8,12 @@ This example uses MNIST training to demonstrate the Transformer Engine usage. Th
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is the `te.fp8_autocast` context manager. If fp8_autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection of the information needed for model training, such as parameters and FP8 metadata; the metadata is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and you print var_collect, you will see the FP8 metadata inside, such as the `fp8_meta_collection` section. Training and evaluating with FP8 have to be done under fp8_autocast; if not, fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to a higher floating point precision, such as BF16 in this example. To check whether FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works correctly, the string returned by jax.make_jaxpr should include the `Float8` keyword.
4. Training process: In `apply_model`, the main difference from normal Flax usage is that, with FP8 training, the FP8 metadata has to be passed into the gradient function `grad_fn`. Otherwise, Transformer Engine doesn't know how to correctly cast the BF16 tensors into FP8 tensors at runtime. The FP8 metadata doesn't belong in the model parameters (`state.params`), so we need to manually combine the metadata and the latest model parameters into var_collect as a frozen dictionary and pass it to the gradient function. After getting the loss and gradient, we also need to call `te.update_fp8_metas` to update the FP8 metadata in the `update_model` routine. The number of training steps between FP8 metadata updates can be customized; in this example, it is updated every step.
4. Training process: In `apply_model`, the main difference from normal Flax usage is that, with FP8 training, the FP8 metadata has to be passed into the gradient function `grad_fn`. Otherwise, Transformer Engine doesn't know how to correctly cast the BF16 tensors into FP8 tensors at runtime. The FP8 metadata doesn't belong in the model parameters (`state.params`), so we need to manually combine the metadata and the latest model parameters into var_collect as a frozen dictionary and pass it to the gradient function.
5. Evaluating process: The evaluating process is the same as the training process; ensure the FP8 metadata is inside var_collect and pass it into the loss function. A sketch of the FP8 setup under `fp8_autocast` follows the option list below.
6. Additional options: The `te.fp8_autocast` context manager has additional options
* FP8 Recipe: control FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options. **Note** that updating FP8 metadata is now the user's responsibility (i.e., manually calling `te.update_fp8_metas`). The JAX version of Transformer Engine cannot update FP8 metadata on its own.
* Sharding Resource: tells Transformer Engine how to perform data parallelism and tensor parallelism. It is covered in more detail in the Encoder examples.
* FP8 Recipe: control FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options.
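
A minimal sketch of the FP8 setup described in steps 3 and 6 (the layer size, input shape, and recipe values are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import DelayedScaling, Format

model = DenseGeneral(features=256)
x = jnp.zeros((32, 784), jnp.bfloat16)

# Initialization, training, and evaluation with FP8 must all run under fp8_autocast.
with te.fp8_autocast(enabled=True,
                     fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID)):
    var_collect = model.init(jax.random.PRNGKey(0), x)
    # var_collect now holds an FP8 meta collection next to 'params'; a jaxpr of
    # the forward pass should mention Float8 when FP8 is active.
```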
## Run ##
......
......@@ -4,8 +4,6 @@
[pytest]
filterwarnings=
ignore:sharding_type of.*:DeprecationWarning
ignore:major_sharding_type of.*:DeprecationWarning
ignore:Fused attention is not enabled.*:UserWarning
ignore:The hookimpl.*:DeprecationWarning
ignore:xmap is an experimental feature and probably has bugs!
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
import pytest
from typing import Callable, Sequence, Union
from typing import Callable, List, Sequence, Union
import jax
import jax.numpy as jnp
......@@ -20,7 +20,6 @@ from transformer_engine.jax.sharding import HIDDEN_AXES, HIDDEN_TP_AXES, \
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
......@@ -30,6 +29,7 @@ DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
INTERMEDIATE = 16
# Only test with FSDP and TP as DP is not used
def generate_fsdp_and_tp_configs():
configs = []
......@@ -56,10 +56,10 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type),
INTERMEDIATE), dtype) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out),
dtype) / jnp.sqrt(INTERMEDIATE)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE),
dtype) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2],
(INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
......@@ -69,20 +69,28 @@ class TestDistributedLayernormMLP:
return (x, gamma, k1, k2, b1, b2)
def layernorm_fp8_mlp_prim_func(self, x: jnp.ndarray, ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray,
bias_1: jnp.ndarray, bias_2: jnp.ndarray,
fp8_max: jnp.ndarray, fp8_metas_amax: jnp.ndarray,
fp8_metas_scale: jnp.ndarray, fp8_metas_scale_inv: jnp.ndarray,
def layernorm_fp8_mlp_prim_func(
self,
x: jnp.ndarray,
ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ('gelu',),
use_bias: bool = True,
multi_gpus: bool = False,
) -> jnp.ndarray:
fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1],
scale_list_1[1], amax_list_1[2], scale_list_1[2])
fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1],
scale_list_2[1], amax_list_2[2], scale_list_2[2])
if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES
......@@ -95,9 +103,10 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
fused_layernorm_fp8_mlp(x, ln_scale, None,
[kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_pkg,
fused_layernorm_fp8_mlp(x,
ln_scale,
None, [kernel_1, kernel_2], [bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type,
layernorm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
......@@ -105,39 +114,49 @@ class TestDistributedLayernormMLP:
activation_type=activation_type,
use_bias=use_bias))
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('activation_type', [("gelu",),
('gelu', 'linear')])
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('use_bias', [True, False])
def test_layernorm_fp8_mlp_primitive(self, mesh_config,
activation_type, use_bias,
input_shape, dtype):
def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape,
dtype):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = 'rmsnorm'
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2,
FP8Helper.AMAX_HISTORY_LEN), jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
]
fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
]
fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32)
]
fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32)
]
inputs = [x, gamma, k1, k2, b1, b2] = \
self.generate_inputs(input_shape, activation_type, use_bias, dtype)
inputs = [*inputs, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv]
static_inputs = [layernorm_type,
activation_type,
use_bias]
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias]
value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func,
argnums=range(len(inputs)))
# Single GPU
single_jitter = jax.jit(value_and_grad_func,
static_argnums=range(len(inputs),
len(static_inputs)+len(inputs)))
len(static_inputs) + len(inputs)))
with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
......@@ -159,28 +178,38 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2, amax_list_1, amax_list_2, scale_list_1, scale_list_2
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding,
None, None, None, None, None)
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding,
None, None, None, None, None))
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
None, None)
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None,
None, None, None))
multi_jitter = jax.jit(
value_and_grad_func,
multi_jitter = jax.jit(value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs),
len(static_inputs)+len(multi_inputs)+1)) # +1 for multi_gpus
len(static_inputs) + len(multi_inputs) +
1)) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
assert_allclose(multi_fwd, single_fwd, dtype=dtype)
for i in range(len(inputs)):
if multi_grads[i] is not None:
assert_allclose(multi_grads[i], single_grads[i], dtype=dtype,
if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(m_grad,
s_grad,
dtype=dtype,
err_msg=f'multi_grads[{i}] is not close')
else:
assert_allclose(multi_grads[i],
single_grads[i],
dtype=dtype,
err_msg=f'multi_grads[{i}] is not close')
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias,
input_shape, dtype, use_fp8):
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype,
use_fp8):
batch, seqlen, hidden_in = input_shape
layernorm_type = 'rmsnorm'
......@@ -201,7 +230,8 @@ class TestDistributedLayernormMLP:
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single, x,
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single,
x,
deterministic=True)
# Multi GPUs
......@@ -209,8 +239,7 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
......@@ -225,10 +254,10 @@ class TestDistributedLayernormMLP:
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name='mlp'
)
name='mlp')
params_sharded = ln_mlp_sharded.init(init_rngs, x)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded, x,
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded,
x,
deterministic=True)
# Make sure params values are the same
......@@ -238,25 +267,28 @@ class TestDistributedLayernormMLP:
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",),
('silu', 'linear'),
('gelu', 'gelu')])
@pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('use_bias', [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias,
input_shape, dtype):
self._test_layernorm_mlp(mesh_config, activation_type, use_bias, input_shape, dtype,
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",),
('gelu', 'linear'),
('gelu', 'gelu')])
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')])
@pytest.mark.parametrize('use_bias', [True, False])
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('dtype', DTYPES)
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias,
input_shape, dtype):
self._test_layernorm_mlp(mesh_config, activation_type, use_bias, input_shape, dtype,
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape,
dtype):
self._test_layernorm_mlp(mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True)
......@@ -45,90 +45,6 @@ class TestFP8Helper(unittest.TestCase):
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_fp8_metas(self):
FP8Helper.initialize(margin=3.0, amax_history_len=3)
seed = 0
key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
num_of_gemm = 10
num_of_meta = FP8Helper.NUM_META_PER_GEMM * num_of_gemm
def select_amax(amaxes):
if FP8Helper.AMAX_COMPUTE_ALGO == AmaxComputeAlgo.MAX:
return jnp.max(amaxes, axis=-1, keepdims=True)
return amaxes[:, 0:1]
def get_fp8_scale(fp8_max, amax, scale):
fp8_max = np.array(fp8_max)
amax = np.array(amax)
scale = np.array(scale)
sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
return sf
amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
scale_meta_shape = (num_of_meta, 1)
fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
fp8_amax_array1 = jax.random.uniform(key1, shape=amax_meta_shape)
fp8_scale_array1 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array1),
jnp.ones(scale_meta_shape))
fp8_scale_inv_array1 = 1 / fp8_scale_array1
fp8_amax_array2 = jax.random.uniform(key2, shape=amax_meta_shape)
fp8_scale_array2 = get_fp8_scale(fp8_max_array, select_amax(fp8_amax_array2),
jnp.ones(scale_meta_shape))
fp8_scale_inv_array2 = 1 / fp8_scale_array2
state = flax.core.frozen_dict.FrozenDict({
FP8Helper.FP8_COLLECTION_NAME: {
"test_update_fp8_metas1": {
FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
},
"test_update_fp8_metas2": {
FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
FP8Helper.FP8_SCALE_NAME: jnp.ones(scale_meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(scale_meta_shape)
}
}
})
updated_state = FP8Helper.update_fp8_metas(state)
state_array, _ = jax.tree_util.tree_flatten(updated_state)
meta_per_gemm = FP8Helper.NUM_META_PER_GEMM + 1
scale_shift = 2
scale_inv_shift = 3
assert_allclose(state_array[0 * meta_per_gemm + scale_shift], fp8_scale_array1)
assert_allclose(state_array[0 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array1)
assert_allclose(state_array[1 * meta_per_gemm + scale_shift], fp8_scale_array2)
assert_allclose(state_array[1 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array2)
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_generate_fp8_max_array(self):
num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2
def get_ref(format_for_test):
refer_list = []
for i in range(num_of_meta):
val = format_for_test.value.max_bwd \
if i % FP8Helper.NUM_META_PER_GEMM == FP8Helper.GRAD_META_IDX_PER_GEMM \
else format_for_test.value.max_fwd
refer_list.append([val])
return jnp.asarray(refer_list)
for fp8_format in FP8Format:
FP8Helper.initialize(fp8_format=fp8_format)
assert_allclose(get_ref(fp8_format), FP8Helper.generate_fp8_max_array(num_of_meta))
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self):
original_val = 0.0
......
......@@ -260,7 +260,6 @@ class BaseRunner:
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
test_others = FP8Helper.update_fp8_metas(test_others)
del tmp_grad, fp8_meta_grad
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
......
......@@ -17,7 +17,7 @@ from utils import assert_allclose
from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
......@@ -89,8 +89,6 @@ class TestLayer:
def loss_and_grads(module, variables, *inputs):
grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
if FP8Helper.is_fp8_enabled():
wgrads = update_fp8_metas(wgrads)
return loss_val, wgrads, dgrad
def input_getter(self, shape, dtype):
......
......@@ -26,7 +26,7 @@ def _load_library():
_TE_JAX_LIB_CTYPES = _load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
......@@ -45,7 +45,6 @@ __all__ = [
'NVTE_FP8_COLLECTION_NAME',
'fp8_autocast',
'update_collections',
'update_fp8_metas',
'get_delayed_scaling',
'MeshResource',
'MajorShardingType',
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence
from typing import List, Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp
......@@ -28,14 +28,11 @@ def type_safe_dot_general(
kernel = jnp.asarray(kernel, x.dtype)
return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
fp8_max = fp8_meta_pkg.fp8_max
amax = fp8_meta_pkg.amax
scale = fp8_meta_pkg.scale
scale_inv = fp8_meta_pkg.scale_inv
amax_list = fp8_meta_pkg.amax_list
scale_list = fp8_meta_pkg.scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
return _fp8_dot(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
contracting_dims)
return _fp8_dot(x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims)
def quantize(x, q_dtype, scale):
......@@ -84,11 +81,11 @@ def get_precision_of_fp8_dot(enable_2xACC: bool):
return jax.lax.Precision.HIGHEST if enable_2xACC else jax.lax.Precision.DEFAULT
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, amax_list: List[jnp.ndarray],
scale_list: List[jnp.ndarray], fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
contracting_dims: Tuple[Sequence[int], Sequence[int]]):
output, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype,
output, _ = _fp8_dot_fwd_rule(x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype,
contracting_dims)
return output
......@@ -96,17 +93,16 @@ def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jn
def _fp8_dot_fwd_rule(
x,
kernel,
fp8_max,
amax,
scale,
scale_inv,
amax_list,
scale_list,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims):
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list)
amax_list = maybe_fm32_to_fp32(*amax_list)
scale_list = maybe_fm32_to_fp32(*scale_list)
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
......@@ -114,19 +110,19 @@ def _fp8_dot_fwd_rule(
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
assert x_shape_suf == kernel_shape_pre
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list,
fp8_dtype_list)
amax_list = FP8MetaPackage.update_amax_list(amax_list)
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
# Note (Ming Huang): Use native cast to allow XLA to handle transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_x, updated_x_amax = quantize(x, fwd_dtype, x_scale)
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
# Note (Ming Huang): Use native cast to allow XLA to handle transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)
......@@ -135,7 +131,7 @@ def _fp8_dot_fwd_rule(
(lhs_contracting_dims, rhs_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
ctx = (casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32)
return output, ctx
......@@ -143,15 +139,13 @@ def _fp8_dot_fwd_rule(
def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # pylint: disable=unused-argument
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \
casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, \
updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \
maybe_fp32_to_fm32 = ctx
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
......@@ -160,7 +154,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p
x_constracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
gt_constracting_dim = tuple(range(grad.ndim - len(x_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
wgrad = fp8_dot_impl(casted_x, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(x_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
......@@ -168,18 +162,22 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p
g_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
amax = amax.at[gemm_x_idx, 0].set(updated_x_amax)
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax)
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
amax_list[FP8MetaPackage.INPUT_IDX] = \
amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax)
amax_list[FP8MetaPackage.WEIGHT_IDX] = \
amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax)
amax_list[FP8MetaPackage.GRAD_IDX] = \
amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
amax_list = maybe_fp32_to_fm32(*amax_list)
scale_list = maybe_fp32_to_fm32(*scale_list)
return dgrad, wgrad, fp8_max, amax, scale, scale_inv
return dgrad, wgrad, amax_list, scale_list
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)
......@@ -6,7 +6,6 @@ Wrapper module for Transformer related layers with FP8 support.
"""
import functools
import operator
import warnings
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
import jax.numpy as jnp
......@@ -48,7 +47,9 @@ def _canonicalize_tuple(x):
def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
if original_init is None:
if original_init is not None:
return original_init
if not zero_centered_gamma:
return nn.initializers.ones
return nn.initializers.zeros
......@@ -270,7 +271,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_axes: Tuple[str, ...] = ('embed',)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
sharding_type = None
def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
......@@ -292,8 +292,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
outputs : jax.numpy.ndarray
Output tensors.
"""
warnings.warn("sharding_type of LayerNorm would be removed in the near feature",
DeprecationWarning)
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
......@@ -307,53 +305,43 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.epsilon)
class TransformerEngineBase(nn.Module):
class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods
"""
Base class of transformer engine
"""
@staticmethod
def get_fp8_metas(num_of_gemm: int) -> List[jnp.ndarray]:
def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage:
"""
Get the FP8 metas
Generate a set of FP8 meta for a GEMM.
"""
num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
axes = ('fp8_meta_axis', 'fp8_meta_history')
fp8_max = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_MAX_NAME,
FP8Helper.generate_fp8_max_array,
num_of_meta,
axes=axes)
fp8_metas_amax = nn_partitioning.variable_with_axes(
FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_AMAX_NAME,
jnp.zeros, (num_of_meta, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32,
axes=axes)
fp8_metas_scale = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_SCALE_NAME,
jnp.ones, (num_of_meta, 1),
input_name_post_fix = f"_i_{postfix}"
weight_name_post_fix = f"_w_{postfix}"
grad_name_post_fix = f"_g_{postfix}"
def generate_a_set(target_postfix):
amax = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
f"{FP8Helper.FP8_AMAX_NAME}{target_postfix}",
jnp.zeros, (FP8Helper.AMAX_HISTORY_LEN,),
jnp.float32,
axes=axes)
fp8_metas_scale_inv = nn_partitioning.variable_with_axes(FP8Helper.FP8_COLLECTION_NAME,
FP8Helper.FP8_SCALE_INV_NAME,
jnp.ones, (num_of_meta, 1),
axes=(None,))
scale = nn_partitioning.variable_with_axes(
FP8Helper.FP8_COLLECTION_NAME,
f"{FP8Helper.FP8_SCALE_NAME}{target_postfix}",
jnp.ones, (1,),
jnp.float32,
axes=axes)
axes=(None,))
return fp8_max.value, fp8_metas_amax.value, fp8_metas_scale.value, fp8_metas_scale_inv.value
return amax.value, scale.value
@staticmethod
def get_fp8_meta_package(num_of_gemm: int) -> FP8MetaPackage:
"""
Get the FP8 metas
"""
fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv = \
TransformerEngineBase.get_fp8_metas(num_of_gemm)
input_amax, input_scale = generate_a_set(input_name_post_fix)
weight_amax, weight_scale = generate_a_set(weight_name_post_fix)
grad_amax, grad_scale = generate_a_set(grad_name_post_fix)
return FP8MetaPackage(num_of_gemm, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax,
grad_scale)
class DenseGeneral(TransformerEngineBase):
......@@ -412,7 +400,6 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
sharding_type = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -434,8 +421,6 @@ class DenseGeneral(TransformerEngineBase):
outputs : jax.numpy.ndarray
Output tensors.
"""
warnings.warn("sharding_type of DenseGeneral would be removed in the near feature",
DeprecationWarning)
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
......@@ -464,14 +449,13 @@ class DenseGeneral(TransformerEngineBase):
bias = None
contract_ind = tuple(range(0, len(axis)))
fp8_gemm_pkg = None
fp8_meta_pkg = None
if FP8Helper.is_fp8_enabled():
fp8_gemm_pkg = \
TransformerEngineBase.get_fp8_meta_package(1)
fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
y = type_safe_dot_general(inputs,
kernel,
fp8_meta_pkg=fp8_gemm_pkg,
fp8_meta_pkg=fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
......@@ -619,7 +603,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
sharding_type = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -646,8 +629,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
warnings.warn("sharding_type of LayerNormDenseGeneral would be removed in the near feature",
DeprecationWarning)
ln_output = None
......@@ -699,17 +680,16 @@ class LayerNormDenseGeneral(TransformerEngineBase):
contract_ind = tuple(range(0, len(axis)))
fp8_meta_package = None
fp8_meta_pkg = None
if FP8Helper.is_fp8_enabled():
fp8_meta_package = \
TransformerEngineBase.get_fp8_meta_package(1)
fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
if fuse_layernorm:
z = layernorm_fp8_dot(y,
kernel,
scale,
ln_bias,
fp8_meta_package,
fp8_meta_pkg,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
......@@ -719,7 +699,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = type_safe_dot_general(y,
kernel,
fp8_meta_pkg=fp8_meta_package,
fp8_meta_pkg=fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
......@@ -906,7 +886,6 @@ class LayerNormMLP(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
major_sharding_type = None
def __post_init__(self):
if self.kernel_init is None:
......@@ -935,31 +914,22 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
warnings.warn("major_sharding_type of LayerNormMLP would be removed in the near feature",
DeprecationWarning)
ln_output = None
fuse_layernorm = FP8Helper.is_fp8_enabled(
) and not self.return_layernorm_output and self.enable_layernorm
gated_act_pool = [('gelu', 'linear'),
('silu', 'linear'),
('relu', 'linear'),
('quick_gelu', 'linear'),
('squared_relu', 'linear')]
act_pool = [('gelu',),
('silu',),
('relu',),
('quick_gelu',),
('squared_relu',)]
gated_act_pool = [('gelu', 'linear'), ('silu', 'linear'), ('relu', 'linear'),
('quick_gelu', 'linear'), ('squared_relu', 'linear')]
act_pool = [('gelu',), ('silu',), ('relu',), ('quick_gelu',), ('squared_relu',)]
normalized_acts = []
for act in self.activations:
if not isinstance(act, str):
return False
normalized_acts.append(act.lower())
normalized_acts = tuple(reversed(normalized_acts)
if normalized_acts[0] == 'linear' else normalized_acts)
normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == 'linear' else normalized_acts)
is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
......@@ -1001,11 +971,11 @@ class LayerNormMLP(TransformerEngineBase):
kernels.append(self.kernel_init(init_key, *init_args))
return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)
num_of_gemm = 2
fp8_meta_package = None
wi_fp8_meta_pkg = None
wo_fp8_meta_pkg = None
if FP8Helper.is_fp8_enabled():
fp8_meta_package = \
TransformerEngineBase.get_fp8_meta_package(num_of_gemm)
wi_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("0")
wo_fp8_meta_pkg = TransformerEngineBase.generate_fp8_meta_set("1")
num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis)
......@@ -1063,7 +1033,7 @@ class LayerNormMLP(TransformerEngineBase):
out = fused_layernorm_fp8_mlp(y,
scale,
ln_bias, [kernel_1, kernel_2], [bias_1, bias_2],
fp8_meta_package,
[wi_fp8_meta_pkg, wo_fp8_meta_pkg],
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
......@@ -1072,18 +1042,17 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes=self.dot_2_input_axes,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type = normalized_acts,
use_bias = self.use_bias)
activation_type=normalized_acts,
use_bias=self.use_bias)
else: # not use_fused_ln_geglu_mlp
# DenseGeneral 1
gemm1_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(0)
if fuse_layernorm:
x = layernorm_fp8_dot(y,
kernel_1,
scale,
ln_bias,
gemm1_fp8_meta_package,
wi_fp8_meta_pkg,
self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
......@@ -1093,7 +1062,7 @@ class LayerNormMLP(TransformerEngineBase):
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
x = type_safe_dot_general(y,
kernel_1,
fp8_meta_pkg=gemm1_fp8_meta_package,
fp8_meta_pkg=wi_fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
......@@ -1159,12 +1128,9 @@ class LayerNormMLP(TransformerEngineBase):
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
# DenseGeneral 2
gemm2_fp8_meta_package = None if fp8_meta_package is None \
else fp8_meta_package.get_package_by_gemm_idx(1)
out = type_safe_dot_general(z,
kernel_2,
fp8_meta_pkg=gemm2_fp8_meta_package,
fp8_meta_pkg=wo_fp8_meta_pkg,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
......
......@@ -6,7 +6,8 @@ Helper module for fp8 meta management
"""
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Optional, Tuple, Union
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
......@@ -85,72 +86,67 @@ class FP8MetaPackage:
A container that contains all required meta data for FP8
"""
NUM_OF_META: int = 3
INPUT_IDX: int = 0
WEIGHT_IDX: int = 1
GRAD_IDX: int = 2
def __init__(
self,
num_of_gemm: int,
fp8_max: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
input_amax: jnp.ndarray,
input_scale: jnp.ndarray,
weight_amax: jnp.ndarray,
weight_scale: jnp.ndarray,
grad_amax: jnp.ndarray,
grad_scale: jnp.ndarray,
) -> None:
total_num_of_meta = num_of_gemm * FP8Helper.NUM_META_PER_GEMM
self._num_of_gemm = num_of_gemm
assert fp8_max.shape[0] == total_num_of_meta
self._fp8_max = fp8_max
assert amax.shape[0] == total_num_of_meta
self._amax = amax
assert scale.shape[0] == total_num_of_meta
self._scale = scale
assert scale_inv.shape[0] == total_num_of_meta
self._scale_inv = scale_inv
@property
def num_of_gemm(self) -> int:
"""
num_of_gemm of this package
"""
return self._num_of_gemm
self._amax_list = [None] * FP8MetaPackage.NUM_OF_META
self._scale_list = [None] * FP8MetaPackage.NUM_OF_META
@property
def fp8_max(self) -> jnp.ndarray:
"""
fp8_max of this package
"""
return self._fp8_max
self._amax_list[FP8MetaPackage.INPUT_IDX] = input_amax
self._scale_list[FP8MetaPackage.INPUT_IDX] = input_scale
self._amax_list[FP8MetaPackage.WEIGHT_IDX] = weight_amax
self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale
self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax
self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale
@property
def amax(self) -> jnp.ndarray:
def amax_list(self) -> List[jnp.ndarray]:
"""
amax of this package
Get the amax list of this package.
"""
return self._amax
return self._amax_list
@property
def scale(self) -> jnp.ndarray:
def scale_list(self) -> List[jnp.ndarray]:
"""
scale of this package
Get the scale list of this package.
"""
return self._scale
return self._scale_list
@property
def scale_inv(self) -> jnp.ndarray:
@staticmethod
def update_amax_list(amax_list: List[jnp.ndarray]) -> jnp.ndarray:
"""
scale_inv of this package
Update the amax history list
"""
return self._scale_inv
updated_amax_list = [FP8Helper.update_amax_history(amax) for amax in amax_list]
return updated_amax_list
def get_package_by_gemm_idx(self, gemm_idx):
@staticmethod
def update_fp8_scale(
amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray],
fp8_dtype_list: List[DType]) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
"""
Get a sub package by gemm_idx
Get the updated scale and scale_inv lists
"""
assert self.num_of_gemm > gemm_idx
meta_start_idx = gemm_idx * FP8Helper.NUM_META_PER_GEMM
meta_end_idx = (gemm_idx + 1) * FP8Helper.NUM_META_PER_GEMM
return FP8MetaPackage(1, self.fp8_max[meta_start_idx:meta_end_idx],
self.amax[meta_start_idx:meta_end_idx],
self.scale[meta_start_idx:meta_end_idx],
self.scale_inv[meta_start_idx:meta_end_idx])
update_scale_list = []
update_scale_inv_list = []
for amax, scale, fp8_dtype in zip(amax_list, scale_list, fp8_dtype_list):
updated_scale, updated_scale_inv = FP8Helper.update_fp8_scale(amax, scale, fp8_dtype)
update_scale_list.append(updated_scale)
update_scale_inv_list.append(updated_scale_inv)
return update_scale_list, update_scale_inv_list
class AmaxComputeAlgo(Enum):
......@@ -159,7 +155,7 @@ class AmaxComputeAlgo(Enum):
MOST_RECENT = "most_recent"
NVTE_FP8_COLLECTION_NAME = "fp8_meta_collection"
NVTE_FP8_COLLECTION_NAME = "fp8_metas"
class FP8Helper:
......@@ -173,15 +169,9 @@ class FP8Helper:
BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1]
AMAX_HISTORY_LEN: int = 1024
AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX
NUM_META_PER_GEMM: int = 3
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
GRAD_META_IDX_PER_GEMM: int = 2
FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
FP8_AMAX_NAME: str = "fp8_meta_amax"
FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
FP8_MAX_NAME: str = "fp8_max"
FP8_AMAX_NAME: str = "amax"
FP8_SCALE_NAME: str = "scale"
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = True
FP8_2X_ACC_WGRAD: bool = True
......@@ -241,77 +231,6 @@ class FP8Helper:
new_coll = new_coll.unfreeze()
return new_coll
@staticmethod
def update_fp8_metas(state: Collection) -> Collection:
"""
Update the FP8 metas
"""
assert isinstance(state, (dict, FrozenDict))
if FP8Helper.FP8_COLLECTION_NAME in state:
frozen_state = FrozenDict(state) if not isinstance(state, FrozenDict) else state
others, fp8_metas = frozen_state.pop(FP8Helper.FP8_COLLECTION_NAME)
fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
new_state = FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})
if not isinstance(state, FrozenDict):
new_state = new_state.unfreeze()
return new_state
return state
@staticmethod
def generate_fp8_max_array(num_of_meta):
"""
Generate the FP8 max array
"""
num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
fp8_max_fwd = jnp.finfo(FP8Helper.FWD_DTYPE).max
fp8_max_bwd = jnp.finfo(FP8Helper.BWD_DTYPE).max
fp8_max_per_gemm = []
for i in range(FP8Helper.NUM_META_PER_GEMM):
val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
else fp8_max_fwd
fp8_max_per_gemm.append([val])
fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)
@staticmethod
def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
"""
Obtain the index about FP8 metas by the given GEMM index.
"""
input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
return input_idx, kernel_idx, grad_idx
@staticmethod
@jax.jit
def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
for i in range(num_of_gemm):
# flattern array is ordered in alphabetical order of collection names
fp8_max_idx = i * num_of_meta_with_max
fp8_amax_idx = fp8_max_idx + 1
fp8_scale_idx = fp8_amax_idx + 1
fp8_scale_inv_idx = fp8_scale_idx + 1
fp8_max = fp8_meta_arrays[fp8_max_idx]
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(fp8_meta_arrays[fp8_amax_idx], axis=-1, keepdims=True)
else:
amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
scale = fp8_meta_arrays[fp8_scale_idx]
sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
fp8_meta_arrays[fp8_scale_idx] = sf
fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
@staticmethod
def generate_fp8_meta_dtype_converter_pair(*args):
"""
......@@ -319,7 +238,7 @@ class FP8Helper:
"""
def identical_fun(*metas):
return metas
return list(metas)
def fm32_to_fp32_fun(*metas):
for meta in metas:
......@@ -349,25 +268,27 @@ class FP8Helper:
return partial_identical_fun, partial_identical_fun
@staticmethod
@jax.jit
def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray:
"""
Update the amax history
"""
updated_amax = jnp.roll(amax, -1, -1)
updated_amax = updated_amax.at[..., 0].set(0)
updated_amax = updated_amax.at[0].set(0)
return updated_amax
@staticmethod
@jax.jit
def update_fp8_scale(fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray) -> jnp.ndarray:
@partial(jax.jit, static_argnums=(2,))
def update_fp8_scale(amax: jnp.ndarray, scale: jnp.ndarray, fp8_dtype: DType) -> jnp.ndarray:
"""
Calculate fp8 scale and scale_inv based on given amax.
"""
fp8_max = jnp.astype(jnp.finfo(fp8_dtype).max, jnp.float32)
if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax, axis=-1, keepdims=True)
else:
amax = amax[..., 0:1]
amax = amax[0:1]
sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
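# This per-tensor update follows the same delayed-scaling formula documented by
# the removed update_fp8_metas helper (margin is FP8Helper.MARGIN):
#   sf        = (fp8_max / amax) / 2**margin
#   scale     = sf if amax > 0 and isfinite(amax), else the previous scale
#   scale_inv = 1 / scale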
......@@ -475,32 +396,6 @@ def update_collections(new: Collection, original: Collection) -> FrozenDict:
return FP8Helper.update_collections(new, original)
def update_fp8_metas(state: Collection) -> Collection:
r"""
Calculate new fp8 scales and its inverse via the followed formula
.. code-block:: python
sf = (fp8_max / amax) / (2 ^ margin)
sf = sf if amax > 0.0, else original_scale
updated_scale = sf if isfinite(amax), else original_scale)
updated_scale_inv = 1/updated_scale
Collection = [dict, flax.core.frozen_dict.FrozenDict]
Parameters
----------
state: Collection
A collection that includes FP8 metas.
Returns
-------
outputs : Collection
The collection with updated FP8 metas.
"""
return FP8Helper.update_fp8_metas(state)
def get_delayed_scaling():
r"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
......
......@@ -4,7 +4,7 @@
"""JAX layernorm modules"""
from functools import partial
from typing import Tuple
from typing import List, Tuple
import jax
import jax.numpy as jnp
......@@ -116,25 +116,23 @@ def layernorm_fp8_dot(
"""
Layernorm + FP8 GEMM
"""
fp8_max = fp8_meta_pkg.fp8_max
amax = fp8_meta_pkg.amax
scale = fp8_meta_pkg.scale
scale_inv = fp8_meta_pkg.scale_inv
amax_list = fp8_meta_pkg.amax_list
scale_list = fp8_meta_pkg.scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
output = _layernorm_fp8_dot(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
layernorm_type, fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
output = _layernorm_fp8_dot(x, kernel, gamma, beta, amax_list, scale_list, layernorm_type,
fwd_dtype, bwd_dtype, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_input_axes)
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14))
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _layernorm_fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, layernorm_type: str, fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, zero_centered_gamma: bool, epsilon: float,
amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray],
layernorm_type: str, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
zero_centered_gamma: bool, epsilon: float,
layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...]):
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, fp8_max, amax, scale, scale_inv,
output, _ = _layernorm_fp8_dot_fwd_rule(x, kernel, gamma, beta, amax_list, scale_list,
layernorm_type, fwd_dtype, bwd_dtype,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_input_axes)
......@@ -146,10 +144,8 @@ def _layernorm_fp8_dot_fwd_rule(
kernel,
gamma,
beta,
fp8_max,
amax,
scale,
scale_inv,
amax_list,
scale_list,
layernorm_type,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
......@@ -163,17 +159,18 @@ def _layernorm_fp8_dot_fwd_rule(
assert x.shape[-1] == kernel.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list)
amax_list = maybe_fm32_to_fp32(*amax_list)
scale_list = maybe_fm32_to_fp32(*scale_list)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list,
fp8_dtype_list)
amax_list = FP8MetaPackage.update_amax_list(amax_list)
gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm_x_idx, 0:1]
x_scale = scale[gemm_x_idx]
x_scale_inv = scale_inv[gemm_x_idx]
x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1]
x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
......@@ -202,9 +199,9 @@ def _layernorm_fp8_dot_fwd_rule(
assert x.shape == ln_out.shape
kernel_amax = amax[gemm_kernel_idx, 0:1]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1]
kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
# Kernel in (hidden_in, hidden_out...)
# Note (Ming Huang): Use cast only to allow XLA to handle transpose for avoiding
......@@ -219,7 +216,7 @@ def _layernorm_fp8_dot_fwd_rule(
(x_contracting_dims, k_contracting_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))
ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
ctx = (ln_out, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims,
k_contracting_dims, maybe_fp32_to_fm32)
......@@ -236,18 +233,16 @@ def _layernorm_fp8_dot_bwd_rule(
dot_input_axes, # pylint: disable=unused-argument
ctx,
grad):
ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \
ln_out_, casted_kernel, amax_list, scale_list, scale_inv_list, \
updated_x_amax, updated_kernel_amax, \
x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx
ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx, 0:1]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = \
cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
......@@ -255,7 +250,7 @@ def _layernorm_fp8_dot_bwd_rule(
xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
gt_constracting_dim = tuple(range(grad.ndim - len(xt_constracting_dim), grad.ndim))
x_scale_inv = scale_inv[gemm_x_idx]
x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
wgrad = fp8_dot_impl(ln_out_t, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
(xt_constracting_dim, gt_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
......@@ -263,7 +258,7 @@ def _layernorm_fp8_dot_bwd_rule(
g_for_dgrad_constracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(k_contracting_dims), grad.ndim))
k_constracting_dim = tuple(range(len(k_contracting_dims), len(kernel_shape)))
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
(g_for_dgrad_constracting_dim, k_constracting_dim),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
......@@ -283,15 +278,19 @@ def _layernorm_fp8_dot_bwd_rule(
dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])
amax_list[FP8MetaPackage.INPUT_IDX] = \
amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
amax_list[FP8MetaPackage.WEIGHT_IDX] = \
amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
amax_list[FP8MetaPackage.GRAD_IDX] = \
amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
amax_list = maybe_fp32_to_fm32(*amax_list)
scale_list = maybe_fp32_to_fm32(*scale_list)
return dx, wgrad, \
dgamma, dbeta, \
fp8_max, amax, scale, scale_inv
amax_list, scale_list
_layernorm_fp8_dot.defvjp(_layernorm_fp8_dot_fwd_rule, _layernorm_fp8_dot_bwd_rule)
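# --- Illustrative sketch (not part of the diff) -------------------------------
# A minimal example of the one-set-per-tensor FP8 meta layout the rules above
# rely on: each tensor of a GEMM (input, weight, grad) carries its own amax
# history and scale, indexed through FP8MetaPackage.{INPUT_IDX, WEIGHT_IDX,
# GRAD_IDX}. The index values and history length below are assumptions for the
# sketch, not the library's exact definitions.
import jax.numpy as jnp

INPUT_IDX, WEIGHT_IDX, GRAD_IDX = 0, 1, 2      # assumed ordering
AMAX_HISTORY_LEN = 1024                        # assumed history length

amax_list = [jnp.zeros((AMAX_HISTORY_LEN,), jnp.float32) for _ in range(3)]
scale_list = [jnp.ones((1,), jnp.float32) for _ in range(3)]

# Read the meta for the kernel (weight) tensor, as the forward rule does:
kernel_amax = amax_list[WEIGHT_IDX][0:1]
kernel_scale = scale_list[WEIGHT_IDX]

# Record a freshly measured amax, as the backward rule does:
updated_kernel_amax = jnp.array([3.0])
amax_list[WEIGHT_IDX] = amax_list[WEIGHT_IDX].at[0].set(updated_kernel_amax[0])
# -------------------------------------------------------------------------------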
......@@ -61,7 +61,7 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray,
beta: jnp.ndarray,
kernels: List[jnp.ndarray],
biases: List[jnp.ndarray],
fp8_gemm_pkg: FP8MetaPackage,
fp8_meta_pkgs: List[FP8MetaPackage],
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
......@@ -77,16 +77,16 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray,
"""
assert len(kernels) == 2
assert fp8_gemm_pkg.num_of_gemm == len(kernels)
assert len(fp8_meta_pkgs) == len(kernels)
kernel_1 = kernels[0]
kernel_2 = kernels[1]
bias_1 = biases[0]
bias_2 = biases[1]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
amax_list_1 = fp8_meta_pkgs[0].amax_list
amax_list_2 = fp8_meta_pkgs[1].amax_list
scale_list_1 = fp8_meta_pkgs[0].scale_list
scale_list_2 = fp8_meta_pkgs[1].scale_list
fwd_dtype = FP8Helper.FWD_DTYPE
bwd_dtype = FP8Helper.BWD_DTYPE
......@@ -97,29 +97,31 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray,
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max,
amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type,
zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name,
ffn2_ckpt_name, activation_type, use_bias)
output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2,
amax_list_1, amax_list_2, scale_list_1, scale_list_2,
fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma,
epsilon, layernorm_input_axes, dot_1_input_axes,
dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
return output
@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))
def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray,
bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype,
bias_2: jnp.ndarray, amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray], fwd_dtype: jnp.dtype,
bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool,
epsilon: float, layernorm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], use_bias: bool):
output, _ = _fused_layernorm_fp8_mlp_fwd_rule(
x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv,
fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes,
dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type,
use_bias)
x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, amax_list_1, amax_list_2, scale_list_1,
scale_list_2, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon,
layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name,
activation_type, use_bias)
return output
......@@ -131,10 +133,10 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
kernel_2,
bias_1,
bias_2,
fp8_max,
amax,
scale,
scale_inv,
amax_list_1,
amax_list_2,
scale_list_1,
scale_list_2,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
layernorm_type,
......@@ -162,17 +164,24 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
assert kernel_1.shape[-1] == kernel_2.shape[0]
maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)
scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
amax = FP8Helper.update_amax_history(amax)
gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
x_amax = amax[gemm1_x_idx, 0:1]
x_scale = scale[gemm1_x_idx]
x_scale_inv = scale_inv[gemm1_x_idx]
FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list_1, *scale_list_1,
*amax_list_2, *scale_list_2)
amax_list_1 = maybe_fm32_to_fp32(*amax_list_1)
scale_list_1 = maybe_fm32_to_fp32(*scale_list_1)
amax_list_2 = maybe_fm32_to_fp32(*amax_list_2)
scale_list_2 = maybe_fm32_to_fp32(*scale_list_2)
fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
scale_list_1, scale_inv_list_1 = FP8MetaPackage.update_fp8_scale(amax_list_1, scale_list_1,
fp8_dtype_list)
amax_list_1 = FP8MetaPackage.update_amax_list(amax_list_1)
scale_list_2, scale_inv_list_2 = FP8MetaPackage.update_fp8_scale(amax_list_2, scale_list_2,
fp8_dtype_list)
amax_list_2 = FP8MetaPackage.update_amax_list(amax_list_2)
x_amax = amax_list_1[FP8MetaPackage.INPUT_IDX][0:1]
x_scale = scale_list_1[FP8MetaPackage.INPUT_IDX]
x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
......@@ -201,9 +210,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
assert x.shape == ln_out.shape
kernel_1_amax = amax[gemm1_kernel_idx, 0:1]
kernel_1_scale = scale[gemm1_kernel_idx]
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
kernel_1_amax = amax_list_1[FP8MetaPackage.WEIGHT_IDX][0:1]
kernel_1_scale = scale_list_1[FP8MetaPackage.WEIGHT_IDX]
kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use cast only and let XLA handle the transpose, to avoid an
    # unnecessary copy that would break FP8 GEMM pattern matching.
......@@ -224,12 +233,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
bias_1_shape = None
dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1)
activation_lu_out_amax = amax[gemm2_x_idx, 0:1]
activation_lu_out_scale = scale[gemm2_x_idx]
activation_lu_out_scale_inv = scale_inv[gemm2_x_idx]
activation_lu_out_amax = amax_list_2[FP8MetaPackage.INPUT_IDX][0:1]
activation_lu_out_scale = scale_list_2[FP8MetaPackage.INPUT_IDX]
activation_lu_out_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = \
......@@ -239,8 +245,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
casted_activation_lu_out, dot_2_input_axes)
kernel_2_scale = scale[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
kernel_2_scale = scale_list_2[FP8MetaPackage.WEIGHT_IDX]
kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use native cast and let XLA handle the transpose, to avoid an
    # unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale)
......@@ -261,9 +267,10 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1,
casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax,
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax,
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32)
casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, scale_inv_list_1,
scale_inv_list_2, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax,
updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape,
maybe_fp32_to_fm32)
return dot_2_output, ctx
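# --- Illustrative sketch (not part of the diff) -------------------------------
# A hedged sketch of what the delayed-scaling helpers used above
# (FP8MetaPackage.update_fp8_scale / update_amax_list) are assumed to do for a
# single tensor: derive a new scale from the amax history and the format's
# representable maximum, then advance the history. fp8_max, margin, and the
# rolling policy are assumptions for the sketch, not the library's code.
import jax.numpy as jnp

def sketch_update_fp8_scale(amax_history, scale, fp8_max=448.0, margin=0.0):
    amax = jnp.max(amax_history)
    new_scale = (fp8_max / amax) / (2.0 ** margin)
    # Keep the previous scale when amax is zero or non-finite.
    new_scale = jnp.where(jnp.isfinite(new_scale) & (amax > 0.0), new_scale, scale)
    return new_scale, 1.0 / new_scale              # (scale, scale_inv)

def sketch_update_amax_list(amax_history):
    # Advance the rolling window and clear the slot for the next measurement.
    return jnp.roll(amax_history, shift=-1).at[-1].set(0.0)
# -------------------------------------------------------------------------------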
......@@ -284,15 +291,14 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
ctx,
grad):
x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \
casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \
casted_kernel_1, casted_kernel_2, amax_list_1, amax_list_2, scale_list_1, scale_list_2, \
scale_inv_list_1, scale_inv_list_2, updated_x_amax, \
updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \
x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx
gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1)
grad_amax = amax[gemm2_grad_idx, 0:1]
grad_scale = scale[gemm2_grad_idx]
grad_scale_inv = scale_inv[gemm2_grad_idx]
grad_amax = amax_list_2[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list_2[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list_2[FP8MetaPackage.GRAD_IDX]
    # The output gradient should be sharded the same way as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
......@@ -316,24 +322,22 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
transpose_axis_boundary=-1)
    # (hidden, batch...) x (hidden, batch...)
gemm2_x_scale_inv = scale_inv[gemm2_x_idx]
gemm2_x_scale_inv = scale_inv_list_2[FP8MetaPackage.INPUT_IDX]
wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv,
grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv = scale_inv[gemm2_kernel_idx]
kernel_2_scale_inv = scale_inv_list_2[FP8MetaPackage.WEIGHT_IDX]
dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv,
grad.dtype, (x_contracting_dims, (1,)),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0)
dactivation_lu_amax = amax[gemm1_grad_idx, 0:1]
dactivation_lu_scale = scale[gemm1_grad_idx]
dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx]
dactivation_lu_amax = amax_list_1[FP8MetaPackage.GRAD_IDX][0:1]
dactivation_lu_scale = scale_list_1[FP8MetaPackage.GRAD_IDX]
dactivation_lu_scale_inv = scale_inv_list_1[FP8MetaPackage.GRAD_IDX]
if len(activation_type) > 1: # if gated
if use_bias:
......@@ -390,15 +394,15 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv[gemm1_x_idx]
gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims)
wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv,
dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2),
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))
x_contracting_dims = ((min(x_contracting_dims),) + tuple(
i + 1 for i in x_contracting_dims), (1,2))
kernel_1_scale_inv = scale_inv[gemm1_kernel_idx]
x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims),
(1, 2))
kernel_1_scale_inv = scale_inv_list_1[FP8MetaPackage.WEIGHT_IDX]
dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv,
kernel_1_scale_inv, grad.dtype, x_contracting_dims,
get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))
......@@ -419,17 +423,26 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0])
amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0])
amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0])
amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0])
amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax)
amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0])
fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)
amax_list_1[FP8MetaPackage.INPUT_IDX] = \
amax_list_1[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0])
amax_list_1[FP8MetaPackage.WEIGHT_IDX] = \
amax_list_1[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_1_amax[0])
amax_list_1[FP8MetaPackage.GRAD_IDX] = \
amax_list_1[FP8MetaPackage.GRAD_IDX].at[0].set(updated_dactivation_lu_amax[0])
amax_list_2[FP8MetaPackage.INPUT_IDX] = \
amax_list_2[FP8MetaPackage.INPUT_IDX].at[0].set(updated_activation_lu_amax[0])
amax_list_2[FP8MetaPackage.WEIGHT_IDX] = \
amax_list_2[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_2_amax)
amax_list_2[FP8MetaPackage.GRAD_IDX] = \
amax_list_2[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
amax_list_1 = maybe_fp32_to_fm32(*amax_list_1)
scale_list_1 = maybe_fp32_to_fm32(*scale_list_1)
amax_list_2 = maybe_fp32_to_fm32(*amax_list_2)
scale_list_2 = maybe_fp32_to_fm32(*scale_list_2)
return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \
fp8_max, amax, scale, scale_inv
amax_list_1, amax_list_2, scale_list_1, scale_list_2
_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule,
......
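# --- Illustrative sketch (not part of the diff) -------------------------------
# A minimal, self-contained custom_vjp example of the pattern used by the two
# rules above: the FP8 meta arrays ride along as differentiable arguments, so
# the backward rule can hand their updated values back in the gradient slots.
# All names below are made up for the sketch.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def meta_dot(x, w, amax):
    return x @ w

def meta_dot_fwd(x, w, amax):
    return x @ w, (x, w, amax)

def meta_dot_bwd(ctx, g):
    x, w, amax = ctx
    new_amax = jnp.maximum(amax, jnp.max(jnp.abs(g)))    # "updated" meta
    return g @ w.T, x.T @ g, new_amax                    # meta returned in its grad slot

meta_dot.defvjp(meta_dot_fwd, meta_dot_bwd)

# Usage: pulling back a cotangent yields dx, dw, and the updated amax together.
x = jnp.ones((2, 3)); w = jnp.ones((3, 4)); amax = jnp.zeros(())
out, pullback = jax.vjp(meta_dot, x, w, amax)
dx, dw, new_amax = pullback(jnp.ones_like(out))
# -------------------------------------------------------------------------------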
......@@ -19,7 +19,6 @@ from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType
from ..sharding import MajorShardingType, ShardingType
def _generate_ln_scale_init(scale_init):
......@@ -76,7 +75,6 @@ class LayerNorm(TransformerEngineBaseLayer):
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
......@@ -106,7 +104,6 @@ class FusedSoftmax(TransformerEngineBaseLayer):
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
......@@ -136,7 +133,6 @@ class Linear(TransformerEngineBaseLayer):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
......@@ -187,7 +183,6 @@ class LayerNormLinear(TransformerEngineBaseLayer):
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
depth_scaling: float = None
sharding_type: ShardingType = ShardingType.SINGLE
def setup(self) -> None:
"""setup"""
......@@ -253,7 +248,6 @@ class LayerNormMLP(TransformerEngineBaseLayer):
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE
def setup(self) -> None:
"""setup"""
......