# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
)

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

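# Only include FP8 recipes that the current device actually supports.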
SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 64, 128]]  # [batch, seqlen, hidden_in]

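# Logical sharding axes for the activation tensors.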
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
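
# Logical sharding axes for the weights and biases.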
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64


# Only test with FSDP and TP as DP is not used
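# Each config is [device_count, mesh_shape, mesh_axis_names, MeshResource].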
def generate_fsdp_and_tp_configs():
    configs = []
    if is_devices_enough(2):
        configs.append(
            [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )
    if is_devices_enough(4):
        configs.append(
            [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )
    return configs


class TestDistributedLayernormMLP:

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
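        # kernel_1: [hidden_in, num_activations, intermediate]; kernel_2: [intermediate, hidden_out],
        # each scaled by 1/sqrt(fan_in) to keep outputs near unit variance.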
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
    ) -> jnp.ndarray:

        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
            kernel_1_axes = KERNEL_1_AXES
            kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = dot_2_input_axes = None
            kernel_1_axes = kernel_2_axes = None

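        # Two quantizer sets: one for each GEMM (dot_1 and dot_2) in the MLP.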
        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)

        # out = act(layernorm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2
        return jnp.mean(
            layernorm_mlp(
                x,
                ln_scale,
                None,  # no layernorm bias (beta) needed for rmsnorm
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                layernorm_type,
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                kernel_1_axes=kernel_1_axes,
                kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    def test_layernorm_mlp_grad(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
    ):
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]
        value_and_grad_func = jax.value_and_grad(
            self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
        )

        # Single GPU
        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
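            # Positions beyond the array inputs (layernorm_type, activation_type) are static.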
            single_jitter = jax.jit(
                value_and_grad_func,
                static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
            )
            single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

        # Multi GPUs
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
            k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
            k1_ = jax.device_put(k1, k1_sharding)
            k2_ = jax.device_put(k2, k2_sharding)
            if use_bias:
                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
                b1_ = jax.device_put(b1, b1_sharding)
            else:
                b1_sharding = b1_ = None
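            # Swap in the sharded copies of the kernels and bias_1.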
            multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

            # Position ref for sharding pspec lists
            #   x, gamma, k1, k2, b1,
            #   b2
            in_shardings = (
                None,
                None,
                k1_sharding,
                k2_sharding,
                b1_sharding,
                None,
            )
            out_shardings = (
                None,
                (None, None, k1_sharding, k2_sharding, b1_sharding, None),
            )

            multi_jitter = jax.jit(
                value_and_grad_func,
                in_shardings=in_shardings,
                out_shardings=out_shardings,
                static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
            )  # +1 for multi_gpus

            multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
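        # Compare gradients one by one; gated activations under FP8 need looser tolerances.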
        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                        )
                else:
                    is_gated = len(activation_type) > 1
                    rtol = None
                    atol = None
                    if is_gated:
                        if dtype == jnp.bfloat16:
                            if i == 2:
                                rtol = 800
                                atol = 9e-2
                            if i == 4:
                                rtol = 300
                                atol = 1e-1
                        if dtype == jnp.float16:
                            if i == 1:  # gamma
                                rtol = 200
                                atol = 1e-2
                            if i == 2:
                                rtol = 2000
                                atol = 7e-2
                            if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():  # bias_1
                                # Accumulating dbias across a large tensor introduces a larger difference
                                rtol = 200
                                atol = 4e-2
                            if i == 4 and fp8_recipe == recipe.DelayedScaling():
                                rtol = 2200
                                atol = 9e-2
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=dtype,
                        rtol=rtol,
                        atol=atol,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

    def _test_layernorm_mlp(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
    ):
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1]}

        # Single GPU
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            ln_mlp_single = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                use_bias=use_bias,
            )
            params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
            mlp_out_single, ln_out_single = ln_mlp_single.apply(
                params_single, x, deterministic=True
            )

        # Multi GPUs
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(
            enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
        ):
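            # The logical axis names below are resolved to mesh axes through the MeshResource.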
            ln_mlp_sharded = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                scale_axes=LN_SCALE_AXES,
                ln_bias_axes=LN_BIAS_AXES,
                kernel_axes_1=KERNEL_1_AXES,
                kernel_axes_2=KERNEL_2_AXES,
                use_bias=use_bias,
                bias_axes_1=BIAS_1_AXES,
                bias_axes_2=BIAS_2_AXES,
                layernorm_input_axes=LAYERNORM_INPUT_AXES,
                dot_1_input_axes=DOT_1_INPUT_AXES,
                dot_2_input_axes=DOT_2_INPUT_AXES,
                name="mlp",
            )
            params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
            mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                params_sharded, x, deterministic=True
            )

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
        self._test_layernorm_mlp(
            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    def test_layernorm_mlp_layer_fp8(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
        )