# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Distributed tests for the JAX LayerNorm-MLP (fused layernorm + 2-layer MLP)."""
import re
from typing import Callable, Sequence, Union, Optional

import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.common import recipe
from transformer_engine.jax import autocast
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import (
    QuantizerFactory,
    ScalingMode,
    get_quantize_config_with_recipe,
    get_supported_quantization_recipes,
    is_fp8_available,
    is_scaling_mode_supported,
)
from transformer_engine.jax.sharding import (
    BATCH_AXES,
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    MeshResource,
    SEQLEN_AXES,
    SEQLEN_TP_AXES,
    W_FSDP_AXES,
    W_JOINED_AXES,
    W_NO_SHARD_AXES,
    W_TP_AXES,
)

# Hardware/driver capability flags for the scaling modes under test.
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

# Wrap each supported recipe in a pytest.param so test ids carry the recipe class name.
SUPPORTED_RECIPES = get_supported_quantization_recipes()
SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES]

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 128, 256]]  # [batch, seqlen, hidden_in]

# Logical axis names for activations entering each stage of the fused op.
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)

# Logical axis names for the weights and biases of the two GEMMs.
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 256
# Only test with FSDP and TPSP as DP is not used
def generate_fsdp_and_tpsp_configs():
    """Build pytest mesh configurations covering FSDP x TPSP layouts.

    Each config is ``[device_count, mesh_shape, mesh_axis_names, MeshResource]``;
    a layout is only included when enough local devices are available.
    """
    # (required devices, (fsdp, tpsp) mesh shape, pytest id)
    candidates = [
        (4, (2, 2), "fsdp2_tpsp2"),
        (2, (1, 2), "fsdp1_tpsp2"),
        (2, (2, 1), "fsdp2_tpsp1"),
    ]
    configs = []
    for device_count, mesh_shape, param_id in candidates:
        if not is_devices_enough(device_count):
            continue
        configs.append(
            pytest.param(
                [
                    device_count,
                    mesh_shape,
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id=param_id,
            )
        )
    return configs


class TestDistributedLayernormMLP:

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        """Create deterministic inputs ``(x, gamma, k1, k2, b1, b2)`` for a LayerNorm-MLP.

        ``hidden_out`` equals ``hidden_in``; biases are ``None`` when ``use_bias``
        is False. All tensors come from a fixed PRNG seed so repeated calls are
        bit-identical.
        """
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in
        num_acts = len(activation_type)

        # Key order matches the original subkey indices: 0=x, 1=k1, 2=k2, 3=b1, 4=b2, 5=gamma.
        x_key, k1_key, k2_key, b1_key, b2_key, gamma_key = jax.random.split(
            jax.random.PRNGKey(0), 6
        )

        x = jax.random.normal(x_key, (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(gamma_key, (hidden_in,), dtype=dtype)

        # Scale kernels by 1/sqrt(fan_in) to keep activation magnitudes O(1).
        k1 = jax.random.normal(k1_key, (hidden_in, num_acts, INTERMEDIATE), dtype) / jnp.sqrt(
            hidden_in
        )
        k2 = jax.random.normal(k2_key, (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)

        b1 = jax.random.normal(b1_key, (num_acts, INTERMEDIATE), dtype) if use_bias else None
        b2 = jax.random.normal(b2_key, (hidden_out,), dtype) if use_bias else None

        return (x, gamma, k1, k2, b1, b2)
    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
        quantization_recipe: Optional[recipe.Recipe] = None,
    ) -> jnp.ndarray:
        """Run the fused layernorm_mlp primitive and reduce its output to a scalar mean.

        Returning a scalar makes this directly usable with ``jax.value_and_grad``.
        When ``multi_gpus`` is True, the module-level logical axis names are passed
        through so the primitive can shard its inputs/weights; otherwise all axis
        hints are ``None`` (single-device execution).
        """
        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
            kernel_1_axes = KERNEL_1_AXES
            kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = dot_2_input_axes = None
            kernel_1_axes = kernel_2_axes = None

        # Two quantizer sets: one per GEMM in the MLP.
        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2, fp8_recipe=quantization_recipe
        )

        # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
        return jnp.mean(
            layernorm_mlp(
                x,
                ln_scale,
                None,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                layernorm_type,
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                kernel_1_axes=kernel_1_axes,
                kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )
    def _test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        """Check value-and-grad parity between single-device and sharded execution.

        Runs ``layernorm_fp8_mlp_prim_func`` once unsharded and once under a
        (fsdp, tpsp) mesh with kernels/biases device_put to their shardings, then
        asserts the forward value and every gradient match within the tolerance
        implied by the recipe's forward/backward dtypes.
        """
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        # Keep both the flat list and the individual names: k1/k2/b1 are re-sharded below.
        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]

        with use_jax_gemm(enabled=with_jax_gemm):
            # Differentiate w.r.t. every tensor input (positions 0..5).
            value_and_grad_func = jax.value_and_grad(
                self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
            )

            # Single GPU
            with autocast(
                enabled=quantization_recipe is not None,
                recipe=quantization_recipe,
                mesh_resource=MeshResource(),
            ):
                # Static args (layernorm_type, activation_type) follow the tensor args.
                single_jitter = jax.jit(
                    value_and_grad_func,
                    static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
                )
                single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

            # Multi GPUs
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, autocast(
                enabled=quantization_recipe is not None,
                recipe=quantization_recipe,
                mesh_resource=mesh_resource,
            ):
                # kernel_1: (hidden_in, act, intermediate) -> shard in/out dims;
                # kernel_2: (intermediate, hidden_out) -> transposed sharding.
                k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
                k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
                k1_ = jax.device_put(k1, k1_sharding)
                k2_ = jax.device_put(k2, k2_sharding)
                if use_bias:
                    b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp"))
                    b1_ = jax.device_put(b1, b1_sharding)
                else:
                    b1_sharding = b1_ = None
                multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

                # Position ref for sharding pspec lists
                #   x, gamma, k1, k2, b1,
                #   b2
                in_shardings = (
                    None,
                    None,
                    k1_sharding,
                    k2_sharding,
                    b1_sharding,
                    None,
                )
                # Output is (loss, grads-per-input) so grads mirror in_shardings.
                out_shardings = (
                    None,
                    (None, None, k1_sharding, k2_sharding, b1_sharding, None),
                )

                multi_jitter = jax.jit(
                    value_and_grad_func,
                    in_shardings=in_shardings,
                    out_shardings=out_shardings,
                    static_argnums=range(
                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                    ),
                )  # +1 for multi_gpus

                multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        # Tolerances derive from the recipe's quantized fwd/bwd dtypes when quantizing.
        fwd_test_type = bwd_test_type = dtype
        if quantization_recipe is not None:
            quantize_config = get_quantize_config_with_recipe(quantization_recipe)
            fwd_test_type = quantize_config.FWD_DTYPE
            bwd_test_type = quantize_config.BWD_DTYPE

        if fwd_test_type == jnp.float16 and use_bias:
            # fp16 + bias accumulates more rounding error; use explicit looser tols.
            assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5)
        else:
            assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)

        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                # Kernel/bias grads come back as lists (one entry per GEMM).
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad,
                            s_grad,
                            dtype=bwd_test_type,
                            err_msg=f"multi_grads[{i}] is not close",
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=bwd_test_type,
                        err_msg=f"multi_grads[{i}] is not close",
                    )
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        """Gradient parity test with the GSPMD partitioner (use_shardy=False)."""
        uses_nvfp4 = quantization_recipe is not None and quantization_recipe.nvfp4()
        if uses_nvfp4 and dtype == jnp.float16:
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")

        self._test_layernorm_mlp_grad(
            mesh_config=mesh_config,
            activation_type=activation_type,
            use_bias=use_bias,
            input_shape=input_shape,
            dtype=dtype,
            quantization_recipe=quantization_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        """Gradient parity test with the Shardy partitioner (use_shardy=True)."""
        uses_nvfp4 = quantization_recipe is not None and quantization_recipe.nvfp4()
        if uses_nvfp4 and dtype == jnp.float16:
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")

        self._test_layernorm_mlp_grad(
            mesh_config=mesh_config,
            activation_type=activation_type,
            use_bias=use_bias,
            input_shape=input_shape,
            dtype=dtype,
            quantization_recipe=quantization_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )
    def _test_layernorm_mlp(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        use_fp8,
        quantization_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        """Check the flax ``LayerNormMLP`` module for single-device vs sharded parity.

        Initializes and applies the module twice — once unsharded and once under
        a (fsdp, tpsp) mesh with logical axis annotations — and asserts that the
        initialized params, the layernorm output, and the MLP output all agree.
        Tolerances are loosened for two known hardware/backend precision cases.
        """
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 3)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]}

        with use_jax_gemm(enabled=with_jax_gemm):
            # Single GPUs
            with autocast(
                enabled=use_fp8, recipe=quantization_recipe, mesh_resource=MeshResource()
            ):
                ln_mlp_single = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    use_bias=use_bias,
                )
                params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
                mlp_out_single, ln_out_single = ln_mlp_single.apply(
                    params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
                )

            # Multi GPUs
            device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, autocast(
                enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource
            ):
                # Same module config plus logical axis annotations for sharding.
                ln_mlp_sharded = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    scale_axes=LN_SCALE_AXES,
                    ln_bias_axes=LN_BIAS_AXES,
                    kernel_axes_1=KERNEL_1_AXES,
                    kernel_axes_2=KERNEL_2_AXES,
                    use_bias=use_bias,
                    bias_axes_1=BIAS_1_AXES,
                    bias_axes_2=BIAS_2_AXES,
                    layernorm_input_axes=LAYERNORM_INPUT_AXES,
                    dot_1_input_axes=DOT_1_INPUT_AXES,
                    dot_2_input_axes=DOT_2_INPUT_AXES,
                    name="mlp",
                )
                params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
                mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                    params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
                )

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)

        # TODO(Phuong): check if these tols updates are still needed
        atol = None
        rtol = None
        # L40 (compute capability 8.9) + delayed-scaling fp16 gelu needs looser tols.
        l40_tolerance_update = (
            get_min_device_compute_capability() == 89
            and use_fp8
            and quantization_recipe.delayed()
            and dtype == jnp.float16
            and activation_type == ("gelu",)
        )
        if l40_tolerance_update:
            atol = 0.04
            rtol = 11

        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
        # Triton backend by default. The error of
        # the Triton FP8 gemm has been verified to be less than or equal
        # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
        # However, Triton can auto-tune a different kernel for the single GPU
        # and multi-GPU run in this test, meaning the diff between single GPU
        # and multi-GPU can be larger in some cases, even though both are
        # within tolerance to the float32 ground truth.
        #
        # BUG FIX: the original expression ended with a trailing comma, which made
        # this a one-element tuple — always truthy — so the loosened tolerances
        # were applied to EVERY configuration. It must be a plain boolean.
        jax_triton_gemm_precision_tolerance_update = (
            with_jax_gemm
            and quantization_recipe is not None
            and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling())
            and dtype in (jnp.bfloat16, jnp.float16)
            and activation_type == ("gelu", "linear")
        )
        if jax_triton_gemm_precision_tolerance_update:
            atol = 0.08
            rtol = 15

        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        """Unquantized flax-layer parity test with the GSPMD partitioner."""
        self._test_layernorm_mlp(
            mesh_config=mesh_config,
            activation_type=activation_type,
            use_bias=use_bias,
            input_shape=input_shape,
            dtype=dtype,
            use_fp8=False,
            quantization_recipe=None,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        """Quantized flax-layer parity test with the GSPMD partitioner."""
        uses_nvfp4 = quantization_recipe is not None and quantization_recipe.nvfp4()
        if uses_nvfp4 and dtype == jnp.float16:
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")

        self._test_layernorm_mlp(
            mesh_config=mesh_config,
            activation_type=activation_type,
            use_bias=use_bias,
            input_shape=input_shape,
            dtype=dtype,
            use_fp8=True,
            quantization_recipe=quantization_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        """Unquantized flax-layer parity test with the Shardy partitioner."""
        self._test_layernorm_mlp(
            mesh_config=mesh_config,
            activation_type=activation_type,
            use_bias=use_bias,
            input_shape=input_shape,
            dtype=dtype,
            use_fp8=False,
            quantization_recipe=None,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        """Quantized flax-layer parity test with the Shardy partitioner."""
        uses_nvfp4 = quantization_recipe is not None and quantization_recipe.nvfp4()
        if uses_nvfp4 and dtype == jnp.float16:
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")

        self._test_layernorm_mlp(
            mesh_config=mesh_config,
            activation_type=activation_type,
            use_bias=use_bias,
            input_shape=input_shape,
            dtype=dtype,
            use_fp8=True,
            quantization_recipe=quantization_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )