Unverified commit be1f647c authored by jberchtold-nvidia, committed by GitHub

[JAX-Q] Distributed MXFP8 flax layer tests (#1643)



MXFP8 flax layer tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 1bbeab1c
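
The tests enabled here drive TransformerEngine's flax `LayerNormMLP` layer under an MXFP8 quantization recipe on a sharded mesh. A minimal single-device sketch of that flow, assuming the public JAX API (`fp8_autocast` and `transformer_engine.common.recipe.MXFP8BlockScaling`); shapes and hyperparameters are illustrative:

    import jax
    import jax.numpy as jnp
    from transformer_engine.common import recipe
    import transformer_engine.jax as te
    from transformer_engine.jax.flax import LayerNormMLP

    x = jnp.zeros((4, 128, 512), dtype=jnp.bfloat16)  # [batch, seqlen, hidden]
    rngs = {"params": jax.random.PRNGKey(0)}

    # Initialize and run the layer under an MXFP8 block-scaling recipe.
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
        layer = LayerNormMLP(
            intermediate_dim=2048,
            activations=("gelu", "linear"),
            transpose_batch_sequence=False,
        )
        params = layer.init(rngs, x, deterministic=True)
        # The layer returns the MLP output and the layernorm output,
        # matching the unpacking used in the tests below.
        mlp_out, ln_out = layer.apply(params, x, deterministic=True)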
@@ -267,9 +267,18 @@ class TestDistributedLayernormMLP:
             transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
             intermediate_dim=INTERMEDIATE,
             activations=activation_type,
+            scale_axes=(W_NO_SHARD_AXES,),
+            ln_bias_axes=(W_NO_SHARD_AXES,),
+            kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
+            kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
             use_bias=use_bias,
+            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
+            bias_axes_2=(W_NO_SHARD_AXES,),
+            layernorm_input_axes=LAYERNORM_INPUT_AXES,
+            dot_1_input_axes=DOT_1_INPUT_AXES,
+            dot_2_input_axes=DOT_2_INPUT_AXES,
         )
-        params_single = ln_mlp_single.init(init_rngs, x)
+        params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
         mlp_out_single, ln_out_single = ln_mlp_single.apply(
             params_single, x, deterministic=True
         )
@@ -298,7 +307,7 @@ class TestDistributedLayernormMLP:
             dot_2_input_axes=DOT_2_INPUT_AXES,
             name="mlp",
         )
-        params_sharded = ln_mlp_sharded.init(init_rngs, x)
+        params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
         mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
             params_sharded, x, deterministic=True
         )
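
Note the `init` calls above now pass `deterministic=True` to match `apply`. Flax traces `__call__` during `init`, so without the flag the dropout branch would be traced and `init` would also expect a dropout RNG stream. A sketch of the distinction, with hypothetical keys:

    # With deterministic=True, only a "params" RNG is needed at init time:
    params = layer.init({"params": key0}, x, deterministic=True)

    # Without it, the traced dropout branch also asks for a "dropout" stream:
    params = layer.init({"params": key0, "dropout": key1}, x)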
@@ -318,20 +327,22 @@ class TestDistributedLayernormMLP:
             mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
         )

-    # TODO: debug
-    # @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    # @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
-    # @pytest_parametrize_wrapper(
-    #     "activation_type", [("gelu",), ("gelu", "linear")]
-    # )
-    # @pytest_parametrize_wrapper("use_bias", [True, False])
-    # @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
-    # @pytest_parametrize_wrapper("dtype", DTYPES)
-    # @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
-    # def test_layernorm_fp8_mlp_layer(
-    #     self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
-    # ):
-    #     self._test_layernorm_mlp(
-    #         mesh_config, activation_type, use_bias, input_shape, dtype,
-    #         use_fp8=True, fp8_recipe=fp8_recipe
-    #     )
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    def test_layernorm_fp8_mlp_layer(
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+    ):
+        self._test_layernorm_mlp(
+            mesh_config,
+            activation_type,
+            use_bias,
+            input_shape,
+            dtype,
+            use_fp8=True,
+            fp8_recipe=fp8_recipe,
+        )
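
With this hunk the previously commented-out FP8 test is restored and gains an `fp8_recipe` parameter, which is what extends these distributed layer tests to MXFP8. The real `SUPPORTED_RECIPES` list lives in the shared test utilities; a hypothetical sketch of its shape, for illustration only:

    from transformer_engine.common import recipe

    # Hypothetical definition; the actual list is built in the test
    # utilities and gated on hardware support for each recipe.
    SUPPORTED_RECIPES = [recipe.DelayedScaling(), recipe.MXFP8BlockScaling()]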
@@ -1088,6 +1088,10 @@ class LayerNormMLP(TransformerEngineBase):
         kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
         if not QuantizeConfig.is_fp8_enabled():
             kernel_1 = kernel_1.astype(input_dtype)
+        if self.kernel_axes_1 is not None:
+            kernel_1 = with_sharding_constraint_by_logical_axes(
+                kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:]
+            )
         hidden_size = inputs.shape[-1]
         hidden_size_tuple = _canonicalize_tuple(hidden_size)
         kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
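
The slice `kernel_axes_1[:-2] + kernel_axes_1[-1:]` drops the second-to-last logical axis. The stored `wi` kernel carries an axis per parameter dimension, e.g. `(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)` for a `(hidden, num_activations, intermediate)` parameter, but the reshape above folds the activation dimension into the compute shape, so the constraint needs one fewer axis name. A sketch of the arithmetic, with stand-in axis labels:

    kernel_axes_1 = ("embed", "act", "mlp")  # stands in for (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
    constraint_axes = kernel_axes_1[:-2] + kernel_axes_1[-1:]
    print(constraint_axes)                   # ("embed", "mlp") -- rank matches the folded compute shape

The `kernel_2` constraint in the next hunk gets the same `is not None` guard, so configurations that leave the axis attributes unset skip the sharding constraint entirely.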
@@ -1105,7 +1109,8 @@ class LayerNormMLP(TransformerEngineBase):
         kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
         if not QuantizeConfig.is_fp8_enabled():
             kernel_2 = kernel_2.astype(input_dtype)
-        kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2)
+        if self.kernel_axes_2 is not None:
+            kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2)

         contract_ind = tuple(range(0, len(axis)))
         if self.use_bias: