# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import (
    HIDDEN_AXES,
    HIDDEN_TP_AXES,
    BATCH_AXES,
    SEQLEN_TP_AXES,
    SEQLEN_AXES,
    W_NO_SHARD_AXES,
    W_FSDP_AXES,
    W_TP_AXES,
    W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource

from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough

is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]]  # [batch, seqlen, hidden_in]

LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
INTERMEDIATE = 16


# Only test with FSDP and TP as DP is not used
def generate_fsdp_and_tp_configs():
    configs = []
    if is_devices_enough(2):
        configs.append(
            [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )

    if is_devices_enough(4):
        configs.append(
            [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )
    return configs


class TestDistributedLayernormMLP:
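    """Compare single-GPU and multi-GPU (FSDP + TP sharded) runs of the LayerNorm MLP."""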

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
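        """Generate the input tensor, LayerNorm scale, and the two MLP kernels/biases."""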
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        k1 = jax.random.normal(
            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
            INTERMEDIATE
        )
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: jnp.ndarray,
        bias_2: jnp.ndarray,
        amax_list_1: List[jnp.ndarray],
        amax_list_2: List[jnp.ndarray],
        scale_list_1: List[jnp.ndarray],
        scale_list_2: List[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ("gelu",),
        use_bias: bool = True,
        multi_gpus: bool = False,
    ) -> jnp.ndarray:
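        """Wrap fused_layernorm_fp8_mlp and reduce its output with jnp.mean so the
        result is a scalar suitable for jax.value_and_grad."""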

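        # Pack the (amax history, scale) pairs of the three FP8 tensors for each GEMM.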
        fp8_meta_pkg1 = FP8MetaPackage(
            amax_list_1[0],
            scale_list_1[0],
            amax_list_1[1],
            scale_list_1[1],
            amax_list_1[2],
            scale_list_1[2],
        )
        fp8_meta_pkg2 = FP8MetaPackage(
            amax_list_2[0],
            scale_list_2[0],
            amax_list_2[1],
            scale_list_2[1],
            amax_list_2[2],
            scale_list_2[2],
        )

        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = None
            dot_2_input_axes = None

        # out = activation(layernorm(x) @ kernel_1 + bias_1) @ kernel_2 + bias_2,
        # reduced with jnp.mean to a scalar for differentiation.
        return jnp.mean(
            fused_layernorm_fp8_mlp(
                x,
                ln_scale,
                None,
                [kernel_1, kernel_2],
                [bias_1, bias_2],
                [fp8_meta_pkg1, fp8_meta_pkg2],
                layernorm_type,
                layernorm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
                activation_type=activation_type,
                use_bias=use_bias,
            )
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_fp8_mlp_primitive(
        self, mesh_config, activation_type, use_bias, input_shape, dtype
    ):
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"

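        # Zero-initialized amax histories and unit scales for the three FP8 tensors of each GEMM.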
        fp8_amax_list_1 = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        fp8_amax_list_2 = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        fp8_scale_list_1 = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        fp8_scale_list_2 = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]

        inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
            input_shape, activation_type, use_bias, dtype
        )
        inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
        static_inputs = [layernorm_type, activation_type, use_bias]
        value_and_grad_func = jax.value_and_grad(
            self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
        )

        # Single GPU
        single_jitter = jax.jit(
            value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
        )
        with fp8_autocast(enabled=True):
            single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

        # Multi GPUs
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
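            # Shard kernel_1 over (fsdp, None, tp) and kernel_2 over (tp, fsdp);
            # bias_1 follows the tp-split intermediate dimension.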
            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
            k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
            k1_ = jax.device_put(k1, k1_sharding)
            k2_ = jax.device_put(k2, k2_sharding)
            if use_bias:
                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
                b1_ = jax.device_put(b1, b1_sharding)
            else:
                b1_sharding = b1_ = None
            multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

            # Positional reference for the in_shardings/out_shardings tuples below:
            #   x, gamma, k1, k2, b1,
            #   b2, amax_list_1, amax_list_2, scale_list_1, scale_list_2
            in_shardings = (
                None,
                None,
                k1_sharding,
                k2_sharding,
                b1_sharding,
                None,
                None,
                None,
                None,
                None,
            )
            out_shardings = (
                None,
                (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
            )

            multi_jitter = jax.jit(
                value_and_grad_func,
                in_shardings=in_shardings,
                out_shardings=out_shardings,
                static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
            )  # +1 for multi_gpus

            multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
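        # Compare gradients; the FP8 amax/scale entries come back as lists and are checked element-wise.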
        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=dtype,
                        err_msg=f"multi_grads[{i}] is not close",
                    )

    def _test_layernorm_mlp(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
    ):
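        """Run the flax LayerNormMLP on a single device and on a sharded mesh,
        then compare the initialized parameters and the outputs."""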
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1]}

        # Single GPU
        with fp8_autocast(enabled=use_fp8):
            ln_mlp_single = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                use_bias=use_bias,
            )
            params_single = ln_mlp_single.init(init_rngs, x)
            mlp_out_single, ln_out_single = ln_mlp_single.apply(
                params_single, x, deterministic=True
            )

        # Multi GPUs
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
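            # Same module as above, but with explicit logical axis annotations so the
            # kernels and biases are sharded over the fsdp/tp mesh axes.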
            ln_mlp_sharded = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                scale_axes=(W_NO_SHARD_AXES,),
                ln_bias_axes=(W_NO_SHARD_AXES,),
                kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
                use_bias=use_bias,
                bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
                bias_axes_2=(W_NO_SHARD_AXES,),
                layernorm_input_axes=LAYERNORM_INPUT_AXES,
                dot_1_input_axes=DOT_1_INPUT_AXES,
                dot_2_input_axes=DOT_2_INPUT_AXES,
                name="mlp",
            )
            params_sharded = ln_mlp_sharded.init(init_rngs, x)
            mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
                params_sharded, x, deterministic=True
            )

        # Make sure the parameter values are the same
        assert_tree_like_allclose(params_sharded["params"], params_single["params"])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)

    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
        self._test_layernorm_mlp(
            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
    @pytest.mark.parametrize("use_bias", [True, False])
    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_layernorm_fp8_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype
    ):
        self._test_layernorm_mlp(
            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
        )