# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
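"""Distributed tests for LayerNormMLP: compare single-GPU and multi-GPU (FSDP +
tensor/sequence-parallel) executions for matching outputs and gradients."""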
from typing import Callable, Sequence, Union, Optional
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 128, 256]]  # [batch, seqlen, hidden_in]

LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
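# Weight shardings follow Megatron-style tensor parallelism: kernel_1 is sharded
# along its output (intermediate) dim (column-parallel) and kernel_2 along its
# input dim (row-parallel), while the FSDP axis shards the remaining weight dim.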
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 256


# Only test with FSDP and TPSP as DP is not used
def generate_fsdp_and_tpsp_configs():
    configs = []
    if is_devices_enough(4):
        configs.append(
            pytest.param(
                [
                    4,
                    (2, 2),
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id="fsdp2_tpsp2",
            )
        )

    if is_devices_enough(2):
        configs.append(
            pytest.param(
                [
                    2,
                    (1, 2),
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id="fsdp1_tpsp2",
            )
        )
        configs.append(
            pytest.param(
                [
                    2,
                    (2, 1),
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id="fsdp2_tpsp1",
            ),
        )
    return configs


class TestDistributedLayernormMLP:
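    """Compare LayerNormMLP outputs and gradients between a single-device run and a
    run sharded across an FSDP x TP/SP device mesh, with and without FP8 recipes."""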

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
    ) -> jnp.ndarray:

        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
            kernel_1_axes = KERNEL_1_AXES
            kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = dot_2_input_axes = None
            kernel_1_axes = kernel_2_axes = None

        # One quantizer set per GEMM in the MLP (dot_1 and dot_2).
        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)

        # out = mean(act(norm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2)
        return jnp.mean(
            layernorm_mlp(
                x,
                ln_scale,
                None,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                layernorm_type,
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                kernel_1_axes=kernel_1_axes,
                kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )

    def _test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]

        with use_jax_gemm(enabled=with_jax_gemm):
            value_and_grad_func = jax.value_and_grad(
                self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
            )

            # Single GPU
            with fp8_autocast(
                enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
            ):
                single_jitter = jax.jit(
                    value_and_grad_func,
                    static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
                )
                single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

            # Multi GPUs
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
                k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
                k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
                k1_ = jax.device_put(k1, k1_sharding)
                k2_ = jax.device_put(k2, k2_sharding)
                if use_bias:
                    b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp"))
                    b1_ = jax.device_put(b1, b1_sharding)
                else:
                    b1_sharding = b1_ = None
                multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

                # Positional reference for the sharding pspec lists below:
                #   (x, gamma, k1, k2, b1, b2)
                in_shardings = (
                    None,
                    None,
                    k1_sharding,
                    k2_sharding,
                    b1_sharding,
                    None,
                )
                out_shardings = (
                    None,
                    (None, None, k1_sharding, k2_sharding, b1_sharding, None),
                )

                multi_jitter = jax.jit(
                    value_and_grad_func,
                    in_shardings=in_shardings,
                    out_shardings=out_shardings,
                    static_argnums=range(
                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                    ),
                )  # +1 for multi_gpus

                multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        # Tolerances follow the effective compute dtype: FP8 E4M3 for the forward
        # pass and FP8 E5M2 for gradients when an FP8 recipe is active.
        fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
        bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2

        assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)

        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad,
                            s_grad,
                            dtype=bwd_test_type,
                            err_msg=f"multi_grads[{i}] is not close",
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=bwd_test_type,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    def _test_layernorm_mlp(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        use_fp8,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1]}

        with use_jax_gemm(enabled=with_jax_gemm):
            # Single GPU
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
                ln_mlp_single = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    use_bias=use_bias,
                )
                params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
                mlp_out_single, ln_out_single = ln_mlp_single.apply(
                    params_single, x, deterministic=True
                )

            # Multi GPUs
            device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
                ln_mlp_sharded = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    scale_axes=LN_SCALE_AXES,
                    ln_bias_axes=LN_BIAS_AXES,
                    kernel_axes_1=KERNEL_1_AXES,
                    kernel_axes_2=KERNEL_2_AXES,
                    use_bias=use_bias,
                    bias_axes_1=BIAS_1_AXES,
                    bias_axes_2=BIAS_2_AXES,
                    layernorm_input_axes=LAYERNORM_INPUT_AXES,
                    dot_1_input_axes=DOT_1_INPUT_AXES,
                    dot_2_input_axes=DOT_2_INPUT_AXES,
                    name="mlp",
                )
                params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
                mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                    params_sharded, x, deterministic=True
                )

        # Make sure parameter values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)

        # TODO(Phuong): check if these tols updates are still needed
        atol = None
        rtol = None
        l40_tolerance_update = (
            get_min_device_compute_capability() == 89
            and fp8_recipe == recipe.DelayedScaling()
            and use_fp8
            and dtype == jnp.float16
            and activation_type == ("gelu",)
        )
        if l40_tolerance_update:
            atol = 0.04
            rtol = 11

        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
        # Triton backend by default. The error of
        # the Triton FP8 gemm has been verified to be less than or equal
        # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
        # However, Triton can auto-tune a different kernel for the single GPU
        # and multi-GPU run in this test, meaning the diff between single GPU
        # and multi-GPU can be larger in some cases, even though both are
        # within tolerance to the float32 ground truth.
        jax_triton_gemm_precision_tolerance_update = (
            with_jax_gemm
            and fp8_recipe is not None
            and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling())
            and dtype in (jnp.bfloat16, jnp.float16)
            and activation_type == ("gelu", "linear")
        )
        if jax_triton_gemm_precision_tolerance_update:
            atol = 0.08
            rtol = 15

        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )