# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
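"""Distributed tests for LayerNormMLP.

Compares single-device and sharded (FSDP + tensor/sequence-parallel) runs of both
the functional `layernorm_mlp` API and the Flax `LayerNormMLP` module, checking
forward outputs and gradients with and without quantization recipes. Mesh configs
that need more devices than are available are omitted.
"""
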
import re
from typing import Callable, Sequence, Union, Optional

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import (
    is_fp8_available,
    ScalingMode,
    get_quantize_config_with_recipe,
)
from transformer_engine.jax import autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import (
    QuantizerFactory,
    get_supported_quantization_recipes,
    is_scaling_mode_supported,
)
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability


is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

SUPPORTED_RECIPES = get_supported_quantization_recipes()
SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES]

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 128, 256]]  # [batch, seqlen, hidden_in]

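# Logical sharding axes for activations and weights. These logical names are
# resolved against the "fsdp" and "tpsp" mesh axes through the MeshResource
# declared in each mesh config below.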
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 256


# Only test FSDP and TP/SP mesh configs; plain data parallelism (DP) is not exercised here.
def generate_fsdp_and_tpsp_configs():
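    """Build mesh configs as (device_count, mesh_shape, mesh_axis_names, MeshResource),
    omitting configs that need more devices than are available."""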
    configs = []
    if is_devices_enough(4):
        configs.append(
            pytest.param(
                [
                    4,
                    (2, 2),
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id="fsdp2_tpsp2",
            )
        )

    if is_devices_enough(2):
        configs.append(
            pytest.param(
                [
                    2,
                    (1, 2),
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id="fsdp1_tpsp2",
            )
        )
        configs.append(
            pytest.param(
                [
                    2,
                    (2, 1),
                    ("fsdp", "tpsp"),
                    MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
                ],
                id="fsdp2_tpsp1",
            )
        )
    return configs


class TestDistributedLayernormMLP:

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
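        """Create random inputs: x, LN scale (gamma), kernels scaled by 1/sqrt(fan-in),
        and optional biases (None when use_bias is False)."""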
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

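    # Functional reference path: mean(act(norm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2),
    # differentiable w.r.t. all tensor inputs via jax.value_and_grad in the tests below.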
    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
        quantization_recipe: Optional[recipe.Recipe] = None,
    ) -> jnp.ndarray:
        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
            kernel_1_axes = KERNEL_1_AXES
            kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = dot_2_input_axes = None
            kernel_1_axes = kernel_2_axes = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2, fp8_recipe=quantization_recipe
        )

        # out = mean(act(norm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2)
        return jnp.mean(
            layernorm_mlp(
                x,
                ln_scale,
                None,  # no LN bias for "rmsnorm"
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                layernorm_type,
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                kernel_1_axes=kernel_1_axes,
                kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )

    def _test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        """Check that single-device and mesh-sharded value_and_grad results agree."""
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]

        with use_jax_gemm(enabled=with_jax_gemm):
            value_and_grad_func = jax.value_and_grad(
                self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
            )

            # Single GPU
            with autocast(
                enabled=quantization_recipe is not None,
                recipe=quantization_recipe,
                mesh_resource=MeshResource(),
            ):
                single_jitter = jax.jit(
                    value_and_grad_func,
                    static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
                )
                single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

            # Multi GPUs
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, autocast(
                enabled=quantization_recipe is not None,
                recipe=quantization_recipe,
                mesh_resource=mesh_resource,
            ):
                # kernel_1 shards over (fsdp, -, tpsp); kernel_2 over (tpsp, fsdp)
                k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
                k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
                k1_ = jax.device_put(k1, k1_sharding)
                k2_ = jax.device_put(k2, k2_sharding)
                if use_bias:
                    b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp"))
                    b1_ = jax.device_put(b1, b1_sharding)
                else:
                    b1_sharding = b1_ = None
                multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

                # Positional reference for the sharding spec tuples:
                #   x, gamma, k1, k2, b1, b2
                in_shardings = (
                    None,
                    None,
                    k1_sharding,
                    k2_sharding,
                    b1_sharding,
                    None,
                )
                out_shardings = (
                    None,
                    (None, None, k1_sharding, k2_sharding, b1_sharding, None),
                )

                multi_jitter = jax.jit(
                    value_and_grad_func,
                    in_shardings=in_shardings,
                    out_shardings=out_shardings,
                    static_argnums=range(
                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                    ),
                )  # +1 for the static multi_gpus flag

                multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        fwd_test_type = bwd_test_type = dtype
        if quantization_recipe is not None:
            quantize_config = get_quantize_config_with_recipe(quantization_recipe)
            fwd_test_type = quantize_config.FWD_DTYPE
            bwd_test_type = quantize_config.BWD_DTYPE

        if fwd_test_type == jnp.float16 and use_bias:
            assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5)
        else:
            assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)

        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad,
                            s_grad,
                            dtype=bwd_test_type,
                            err_msg=f"multi_grads[{i}] is not close",
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=bwd_test_type,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

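    # Grad parity tests: the plain variant uses JAX's default GSPMD partitioner,
    # while the _shardy variant re-runs the same check with Shardy enabled.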
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            quantization_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            quantization_recipe=quantization_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    def _test_layernorm_mlp(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        use_fp8,
        quantization_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        """Check that single-device and mesh-sharded Flax LayerNormMLP modules agree."""
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 3)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]}

        with use_jax_gemm(enabled=with_jax_gemm):
            # Single GPU
            with autocast(
                enabled=use_fp8, recipe=quantization_recipe, mesh_resource=MeshResource()
            ):
                ln_mlp_single = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    use_bias=use_bias,
                    return_layernorm_output=True,
                )
                params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
                mlp_out_single, ln_out_single = ln_mlp_single.apply(
                    params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
                )

            # Multi GPUs
            device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, autocast(
                enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource
            ):
                ln_mlp_sharded = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    scale_axes=LN_SCALE_AXES,
                    ln_bias_axes=LN_BIAS_AXES,
                    kernel_axes_1=KERNEL_1_AXES,
                    kernel_axes_2=KERNEL_2_AXES,
                    use_bias=use_bias,
                    bias_axes_1=BIAS_1_AXES,
                    bias_axes_2=BIAS_2_AXES,
                    layernorm_input_axes=LAYERNORM_INPUT_AXES,
                    dot_1_input_axes=DOT_1_INPUT_AXES,
                    dot_2_input_axes=DOT_2_INPUT_AXES,
                    name="mlp",
                    return_layernorm_output=True,
                )
                params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
                mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                    params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
                )

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)

        # TODO(Phuong): check if these tolerance updates are still needed
        atol = None
        rtol = None
        l40_tolerance_update = (
            get_min_device_compute_capability() == 89
            and use_fp8
            and quantization_recipe.delayed()
            and dtype == jnp.float16
            and activation_type == ("gelu",)
        )
        if l40_tolerance_update:
            atol = 0.04
            rtol = 11

        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the Triton backend by
        # default. The error of the Triton FP8 GEMM has been verified to be less
        # than or equal to the error of the cuDNN FP8 GEMM w.r.t. a float32
        # ground truth. However, Triton can auto-tune a different kernel for the
        # single-GPU and multi-GPU runs in this test, so the diff between single
        # GPU and multi GPU can be larger in some cases, even though both are
        # within tolerance of the float32 ground truth.
        jax_triton_gemm_precision_tolerance_update = (
            with_jax_gemm
            and quantization_recipe is not None
            and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling())
            and dtype in (jnp.bfloat16, jnp.float16)
            and activation_type == ("gelu", "linear")
        )
        if jax_triton_gemm_precision_tolerance_update:
            atol = 0.08
            rtol = 15

        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            quantization_recipe=None,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

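    # Module-level parity under quantization: each supported recipe runs through
    # the Flax LayerNormMLP forward pass on a single device and on the mesh.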
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            quantization_recipe=quantization_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
520
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
Alp Dener's avatar
Alp Dener committed
521
522
523
524
525
526
527
528
529
530
531
532
533
534
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
535
            quantization_recipe=None,
Alp Dener's avatar
Alp Dener committed
536
537
538
539
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        quantization_recipe,
        with_jax_gemm,
    ):
        if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
            pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            quantization_recipe=quantization_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )