# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional

import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
)

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
# Capability probes: which FP8 flavors the current hardware supports.
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

# Collect every FP8 recipe the current hardware can run, as pytest params.
SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]]  # [batch, seqlen, hidden_in]

# Logical sharding axes for the layernorm input and the two GEMM inputs.
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
# Width of the MLP intermediate (hidden) layer used throughout these tests.
INTERMEDIATE = 64
# Only test with FSDP and TP as DP is not used
def generate_fsdp_and_tp_configs():
    """Return mesh configs [device_count, mesh_shape, mesh_axes, MeshResource]
    for every FSDP+TP layout the available device count allows."""
    configs = []
    # (device_count, mesh_shape) candidates; both use ("fsdp", "tp") axes.
    for num_devices, mesh_shape in ((2, (1, 2)), (4, (2, 2))):
        if is_devices_enough(num_devices):
            configs.append(
                [
                    num_devices,
                    mesh_shape,
                    ("fsdp", "tp"),
                    MeshResource(fsdp_resource="fsdp", tp_resource="tp"),
                ]
            )
    return configs
class TestDistributedLayernormMLP:
    """Parity tests: sharded (FSDP+TP) vs. single-device LayerNormMLP."""

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        """Build deterministic random inputs for a LayerNorm-MLP.

        Returns (x, gamma, kernel_1, kernel_2, bias_1, bias_2); the bias
        entries are None when use_bias is False.
        """
        batch, seqlen, hidden_in = input_shape
        # Output width matches input width so the MLP is residual-compatible.
        hidden_out = hidden_in

        subkeys = jax.random.split(jax.random.PRNGKey(0), 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        # Kernel 1 is widened by the number of activations (gated activations
        # consume two projections); both kernels are scaled by 1/sqrt(fan_in).
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = b2 = None

        return (x, gamma, k1, k2, b1, b2)
    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: Optional[jnp.ndarray],
        bias_2: Optional[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        multi_gpus: bool = False,
    ) -> jnp.ndarray:
        """Run the layernorm_mlp primitive and reduce its output to a scalar
        mean, so the result can be fed directly to jax.value_and_grad."""
        # Logical sharding axes are only meaningful when running sharded.
        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
        else:
            layernorm_input_axes = dot_1_input_axes = dot_2_input_axes = None

        # One quantizer set per GEMM (two GEMMs in the MLP).
        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)

        # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
        out = layernorm_mlp(
            x,
            ln_scale,
            None,
            [kernel_1, kernel_2],
            [bias_1, bias_2],
            layernorm_type,
            norm_input_axes=layernorm_input_axes,
            dot_1_input_axes=dot_1_input_axes,
            dot_2_input_axes=dot_2_input_axes,
            activation_type=activation_type,
            quantizer_sets=quantizer_sets,
        )
        return jnp.mean(out)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    def test_layernorm_fp8_mlp_primitive(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
    ):
        """Check that the FP8 layernorm_mlp primitive gives the same forward
        value and gradients on a single device and when sharded over FSDP+TP."""
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]
        # Differentiate w.r.t. every tensor input (not the static config args).
        value_and_grad_func = jax.value_and_grad(
            self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
        )

        # Single GPU reference run.
        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            single_jitter = jax.jit(
                value_and_grad_func,
                static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
            )
            single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

        # Multi GPUs: shard the kernels (and bias_1) over the mesh.
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
            k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
            k1_ = jax.device_put(k1, k1_sharding)
            k2_ = jax.device_put(k2, k2_sharding)
            if use_bias:
                b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
                b1_ = jax.device_put(b1, b1_sharding)
            else:
                b1_sharding = b1_ = None
            multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

            # Position ref for sharding pspec lists
            #   x, gamma, k1, k2, b1,
            #   b2
            in_shardings = (
                None,
                None,
                k1_sharding,
                k2_sharding,
                b1_sharding,
                None,
            )
            out_shardings = (
                None,
                (None, None, k1_sharding, k2_sharding, b1_sharding, None),
            )

            multi_jitter = jax.jit(
                value_and_grad_func,
                in_shardings=in_shardings,
                out_shardings=out_shardings,
                static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
            )  # +1 for multi_gpus

            multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                        )
                else:
                    # Gated activations amplify FP8 quantization noise, so some
                    # gradients need relaxed per-index tolerances. Convention in
                    # this file: rtol gets the large multiplier, atol the small
                    # absolute bound.
                    is_gated = len(activation_type) > 1
                    rtol = None
                    atol = None
                    if is_gated:
                        if dtype == jnp.bfloat16:
                            if i == 2:
                                rtol = 800
                                atol = 9e-2
                            if i == 4:
                                # BUG FIX: rtol/atol were swapped here (atol=300,
                                # rtol=1e-1), which made the absolute-tolerance
                                # check vacuous. Restore the file-wide convention.
                                rtol = 300
                                atol = 1e-1
                        if dtype == jnp.float16:
                            if i == 1:  # gamma
                                rtol = 200
                                atol = 1e-2
                            if i == 2:
                                rtol = 2000
                                atol = 7e-2
                            if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():  # bias_1
                                # Accumulating dbias across a large tensor introduces a larger difference
                                rtol = 200
                                atol = 4e-2
                            if i == 4 and fp8_recipe == recipe.DelayedScaling():
                                rtol = 2200
                                atol = 9e-2
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=dtype,
                        rtol=rtol,
                        atol=atol,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

    def _test_layernorm_mlp(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
    ):
        """Shared driver: run the flax LayerNormMLP layer on a single device and
        on an FSDP+TP mesh, then assert params and outputs match."""
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        root_key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(root_key, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        # Same init RNG for both runs so parameters are directly comparable.
        init_rngs = {"params": subkeys[1]}

        # Single-device reference.
        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
            ln_mlp_single = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                use_bias=use_bias,
            )
            params_single = ln_mlp_single.init(init_rngs, x)
            mlp_out_single, ln_out_single = ln_mlp_single.apply(
                params_single, x, deterministic=True
            )

        # Sharded run on the requested mesh.
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(
            enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
        ):
            ln_mlp_sharded = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
                use_bias=use_bias,
                bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
                bias_axes_2=(W_NO_SHARD_AXES,),
                layernorm_input_axes=LAYERNORM_INPUT_AXES,
                dot_1_input_axes=DOT_1_INPUT_AXES,
                dot_2_input_axes=DOT_2_INPUT_AXES,
                name="mlp",
            )
            params_sharded = ln_mlp_sharded.init(init_rngs, x)
            mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                params_sharded, x, deterministic=True
            )

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)

311
312
313
314
315
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
316
    def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
317
318
319
        self._test_layernorm_mlp(
            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
        )
320

321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    # TODO: debug
    # @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    # @pytest_parametrize_wrapper(
    #     "activation_type", [("gelu",), ("gelu", "linear")]
    # )
    # @pytest_parametrize_wrapper("use_bias", [True, False])
    # @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    # @pytest_parametrize_wrapper("dtype", DTYPES)
    # @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    # def test_layernorm_fp8_mlp_layer(
    #     self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
    # ):
    #     self._test_layernorm_mlp(
    #         mesh_config, activation_type, use_bias, input_shape, dtype,
    #         use_fp8=True, fp8_recipe=fp8_recipe
    #     )