# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 64, 128]]  # [batch, seqlen, hidden_in]

# Logical sharding axes for activations and weights; the *_TP_* axes are split
# along the tensor-parallel "tpsp" mesh axis and W_FSDP_AXES along "fsdp",
# as configured by the MeshResource in generate_fsdp_and_tpsp_configs().
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64


# Only test with FSDP and TPSP as DP is not used
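# Each config is [device_count, mesh_shape, mesh_axis_names, MeshResource].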
def generate_fsdp_and_tpsp_configs():
    configs = []
    if is_devices_enough(2):
        configs.append(
            [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
        )
    if is_devices_enough(4):
        configs.append(
            [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
        )
    return configs


class TestDistributedLayernormMLP:
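    """Single-GPU vs. sharded multi-GPU consistency tests for LayerNormMLP.

    Covers both the functional ``layernorm_mlp`` path (values and gradients) and
    the Flax ``LayerNormMLP`` module, with and without FP8 recipes.
    """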

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
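        """Generate random input, LN scale, both kernels, and (optionally) both biases.

        Biases are returned as None when ``use_bias`` is False.
        """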
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
    ) -> jnp.ndarray:
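        """Return the mean of the functional layernorm_mlp output.

        When ``multi_gpus`` is True, the module-level logical sharding axes are
        passed through so the primitive can be partitioned over the device mesh.
        """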

        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
            kernel_1_axes = KERNEL_1_AXES
            kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = dot_2_input_axes = None
            kernel_1_axes = kernel_2_axes = None

        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)

        # out = mean(act(layernorm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2)
        return jnp.mean(
            layernorm_mlp(
                x,
                ln_scale,
                None,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                layernorm_type,
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                kernel_1_axes=kernel_1_axes,
                kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )

    def _test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
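        """Check that value and gradients of the functional layernorm_mlp match
        between a single-device run and a sharded run on the given mesh."""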
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]

        with use_jax_gemm(enabled=with_jax_gemm):
            value_and_grad_func = jax.value_and_grad(
                self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
            )

            # Single GPU
            with fp8_autocast(
                enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
            ):
                single_jitter = jax.jit(
                    value_and_grad_func,
                    static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
                )
                single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

            # Multi GPUs
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
                k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
                k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
                k1_ = jax.device_put(k1, k1_sharding)
                k2_ = jax.device_put(k2, k2_sharding)
                if use_bias:
                    b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp"))
                    b1_ = jax.device_put(b1, b1_sharding)
                else:
                    b1_sharding = b1_ = None
                multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

                # Positional reference for the sharding spec tuples below:
                #   (x, gamma, k1, k2, b1, b2)
                in_shardings = (
                    None,
                    None,
                    k1_sharding,
                    k2_sharding,
                    b1_sharding,
                    None,
                )
                out_shardings = (
                    None,
                    (None, None, k1_sharding, k2_sharding, b1_sharding, None),
                )

                multi_jitter = jax.jit(
                    value_and_grad_func,
                    in_shardings=in_shardings,
                    out_shardings=out_shardings,
                    static_argnums=range(
                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                    ),
                )  # +1 for multi_gpus

                multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        # With an FP8 recipe active, compare using tolerances for the FP8 dtypes
        # actually used: e4m3 in the forward pass and e5m2 for gradients.
        fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
        bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2

        if fwd_test_type == jnp.float16 and use_bias:
            assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
        else:
            assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)

        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad,
                            s_grad,
                            dtype=bwd_test_type,
                            err_msg=f"multi_grads[{i}] is not close",
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=bwd_test_type,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    def _test_layernorm_mlp(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        use_fp8,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
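        """Check that the Flax LayerNormMLP module produces matching params and
        outputs between a single-device run and a sharded run on the given mesh."""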
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1]}

        with use_jax_gemm(enabled=with_jax_gemm):
            # Single GPUs
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
                ln_mlp_single = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    use_bias=use_bias,
                )
                params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
                mlp_out_single, ln_out_single = ln_mlp_single.apply(
                    params_single, x, deterministic=True
                )

            # Multi GPUs
            device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
                ln_mlp_sharded = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    scale_axes=LN_SCALE_AXES,
                    ln_bias_axes=LN_BIAS_AXES,
                    kernel_axes_1=KERNEL_1_AXES,
                    kernel_axes_2=KERNEL_2_AXES,
                    use_bias=use_bias,
                    bias_axes_1=BIAS_1_AXES,
                    bias_axes_2=BIAS_2_AXES,
                    layernorm_input_axes=LAYERNORM_INPUT_AXES,
                    dot_1_input_axes=DOT_1_INPUT_AXES,
                    dot_2_input_axes=DOT_2_INPUT_AXES,
                    name="mlp",
                )
                params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
                mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                    params_sharded, x, deterministic=True
                )

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)

        atol = None
        rtol = None
        l40_tolerance_update = (
            get_min_device_compute_capability() == 89
            and fp8_recipe == recipe.DelayedScaling()
            and use_fp8
            and dtype == jnp.float16
            and activation_type == ("gelu",)
        )
        if l40_tolerance_update:
            atol = 0.04
            rtol = 11

        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
        # Triton backend by default. The error of
        # the Triton FP8 gemm has been verified to be less than or equal
        # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
        # However, Triton can auto-tune a different kernel for the single GPU
        # and multi-GPU run in this test, meaning the diff between single GPU
        # and multi-GPU can be larger in some cases, even though both are
        # within tolerance to the float32 ground truth.
        jax_triton_gemm_precision_tolerance_update = (
            with_jax_gemm
            and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
            and dtype == jnp.bfloat16
            and activation_type == ("gelu", "linear")
        )
        if jax_triton_gemm_precision_tolerance_update:
            atol = 0.08
            rtol = 15

        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )