# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 64, 128]]  # [batch, seqlen, hidden_in]

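# Logical axis annotations used to shard activations and weights over the
# ("fsdp", "tp") mesh in the multi-GPU cases below.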
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64


# Only test with FSDP and TP as DP is not used
def generate_fsdp_and_tp_configs():
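    # Each config is [device_count, mesh_shape, mesh_axis_names, MeshResource].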
    configs = []
    if is_devices_enough(2):
        configs.append(
            [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )
    if is_devices_enough(4):
        configs.append(
            [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )
    return configs


class TestDistributedLayernormMLP:

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
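        # kernel_1 carries one projection per activation branch, so its middle
        # dimension is len(activation_type) (e.g. 2 for gated ("gelu", "linear")).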
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
    ) -> jnp.ndarray:
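        # Multi-GPU runs annotate activations and kernels with the module-level
        # logical axes; the single-GPU reference leaves everything unsharded.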

        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
            kernel_1_axes = KERNEL_1_AXES
            kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = dot_2_input_axes = None
            kernel_1_axes = kernel_2_axes = None

        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)

        # out = mean(act(norm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2)
        return jnp.mean(
            layernorm_mlp(
                x,
                ln_scale,
                None,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                layernorm_type,
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                kernel_1_axes=kernel_1_axes,
                kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )

    def _test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]

        with use_jax_gemm(enabled=with_jax_gemm):
            value_and_grad_func = jax.value_and_grad(
                self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
            )

            # Single GPU
            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                single_jitter = jax.jit(
                    value_and_grad_func,
                    static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
                )
                single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

            # Multi GPUs
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
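                # kernel_1 is column-parallel ("fsdp", None, "tp") and kernel_2 is
                # row-parallel ("tp", "fsdp"), matching KERNEL_1_AXES / KERNEL_2_AXES.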
                k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
                k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
                k1_ = jax.device_put(k1, k1_sharding)
                k2_ = jax.device_put(k2, k2_sharding)
                if use_bias:
                    b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
                    b1_ = jax.device_put(b1, b1_sharding)
                else:
                    b1_sharding = b1_ = None
                multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

                # Positional order for the sharding specs below:
                #   x, gamma, k1, k2, b1, b2
                in_shardings = (
                    None,
                    None,
                    k1_sharding,
                    k2_sharding,
                    b1_sharding,
                    None,
                )
                out_shardings = (
                    None,
                    (None, None, k1_sharding, k2_sharding, b1_sharding, None),
                )

                multi_jitter = jax.jit(
                    value_and_grad_func,
                    in_shardings=in_shardings,
                    out_shardings=out_shardings,
                    static_argnums=range(
                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                    ),
                )  # +1 for multi_gpus

                multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
        bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
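        # The comparisons below use tolerances appropriate to the FP8 formats of the
        # forward (e4m3) and backward (e5m2) passes when an FP8 recipe is active.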
        assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad,
                            s_grad,
                            dtype=bwd_test_type,
                            err_msg=f"multi_grads[{i}] is not close",
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=bwd_test_type,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    def _test_layernorm_mlp(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        use_fp8,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1]}
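        # Both the single-GPU and sharded modules are initialized from the same RNG so
        # their parameters match exactly (verified by assert_tree_like_allclose below).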

        with use_jax_gemm(enabled=with_jax_gemm):
            # Single GPUs
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                ln_mlp_single = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    use_bias=use_bias,
                )
                params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
                mlp_out_single, ln_out_single = ln_mlp_single.apply(
                    params_single, x, deterministic=True
                )

            # Multi GPUs
            device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
                ln_mlp_sharded = LayerNormMLP(
                    layernorm_type=layernorm_type,
                    transpose_batch_sequence=False,
                    intermediate_dim=INTERMEDIATE,
                    activations=activation_type,
                    scale_axes=LN_SCALE_AXES,
                    ln_bias_axes=LN_BIAS_AXES,
                    kernel_axes_1=KERNEL_1_AXES,
                    kernel_axes_2=KERNEL_2_AXES,
                    use_bias=use_bias,
                    bias_axes_1=BIAS_1_AXES,
                    bias_axes_2=BIAS_2_AXES,
                    layernorm_input_axes=LAYERNORM_INPUT_AXES,
                    dot_1_input_axes=DOT_1_INPUT_AXES,
                    dot_2_input_axes=DOT_2_INPUT_AXES,
                    name="mlp",
                )
                params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
                mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                    params_sharded, x, deterministic=True
                )

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)

        atol = None
        rtol = None
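        # Loosen tolerances on compute capability 8.9 devices (e.g. L40) for the
        # FP16 + DelayedScaling + GELU case.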
        l40_tolerance_update = (
            get_min_device_compute_capability() == 89
            and fp8_recipe == recipe.DelayedScaling()
            and use_fp8
            and dtype == jnp.float16
            and activation_type == ("gelu",)
        )
        if l40_tolerance_update:
            atol = 0.04
            rtol = 11

        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )