# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import HIDDEN_AXES, HIDDEN_TP_AXES, \
    BATCH_AXES, SEQLEN_TP_AXES, SEQLEN_AXES, \
    W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough

# FP8 availability flag and the accompanying reason string; both are passed to
# `pytest.mark.skipif` on the FP8 tests below.
is_fp8_supported, reason = is_fp8_available()
# Dtypes every test in this module is parametrized over.
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]]    # [batch, seqlen, hidden_in]

# Logical axis names handed to the fused MLP / LayerNormMLP as
# `layernorm_input_axes`, `dot_1_input_axes` and `dot_2_input_axes`
# in the multi-GPU code paths.
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
# Intermediate dimension: last dim of the first kernel, first dim of the
# second kernel, and `intermediate_dim` of LayerNormMLP.
INTERMEDIATE = 16

# Only test with FSDP and TP, as DP is not used
def generate_fsdp_and_tp_configs():
    """Return the FSDP+TP mesh configurations runnable on this host.

    Each entry is ``[device_count, mesh_shape, mesh_axes, mesh_resource]``.
    A candidate is included only when ``is_devices_enough(device_count)``
    is true, so the result may be empty.
    """
    # (device_count, mesh_shape) candidates, smallest first.
    candidates = ((2, (1, 2)), (4, (2, 2)))
    return [[
        device_count, mesh_shape, ('fsdp', 'tp'),
        MeshResource(fsdp_resource='fsdp', tp_resource='tp')
    ] for device_count, mesh_shape in candidates if is_devices_enough(device_count)]


class TestDistributedLayernormMLP:

    def generate_inputs(self, input_shape, activation_type, use_bias, dtype):
        """Build deterministic random inputs for the fused LayerNorm-MLP test.

        Args:
            input_shape: ``[batch, seqlen, hidden_in]`` of the activation tensor.
            activation_type: Sequence of activations. Only its length is used
                here: it is the middle dimension of ``k1`` and the leading
                dimension of ``b1``.
            use_bias: Whether to generate the two bias tensors.
            dtype: Dtype passed to ``jax.random.normal`` for every tensor.

        Returns:
            Tuple ``(x, gamma, k1, k2, b1, b2)``. ``b1`` and ``b2`` are ``None``
            when ``use_bias`` is false.
        """
        batch, seqlen, hidden_in = input_shape
        hidden_out = hidden_in

        # Fixed seed so the single- and multi-GPU runs see identical data.
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        # Each kernel is divided by the square root of its input dimension.
        k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE),
                               dtype) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2],
                               (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)
        if use_bias:
            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
            b2 = None

        return (x, gamma, k1, k2, b1, b2)

    def layernorm_fp8_mlp_prim_func(
        self,
        x: jnp.ndarray,
        ln_scale: jnp.ndarray,
        kernel_1: jnp.ndarray,
        kernel_2: jnp.ndarray,
        bias_1: jnp.ndarray,
        bias_2: jnp.ndarray,
        amax_list_1: List[jnp.ndarray],
        amax_list_2: List[jnp.ndarray],
        scale_list_1: List[jnp.ndarray],
        scale_list_2: List[jnp.ndarray],
        layernorm_type: str = "rmsnorm",
        activation_type: Sequence[Union[str, Callable]] = ('gelu',),
        use_bias: bool = True,
        multi_gpus: bool = False,
    ) -> jnp.ndarray:
        """Run ``fused_layernorm_fp8_mlp`` and reduce its output to a scalar mean.

        The scalar result makes this function usable with ``jax.value_and_grad``.

        Args:
            x: Input activations.
            ln_scale: LayerNorm scale (gamma).
            kernel_1: Kernel of the first dense layer.
            kernel_2: Kernel of the second dense layer.
            bias_1: Bias of the first dense layer; the caller passes ``None``
                when biases are disabled.
            bias_2: Bias of the second dense layer; may likewise be ``None``.
            amax_list_1: Three amax arrays for the first FP8 meta package.
            amax_list_2: Three amax arrays for the second FP8 meta package.
            scale_list_1: Three scale arrays for the first FP8 meta package.
            scale_list_2: Three scale arrays for the second FP8 meta package.
            layernorm_type: Normalisation type forwarded to the fused op.
            activation_type: Activation(s) forwarded to the fused op.
            use_bias: Forwarded to the fused op.
            multi_gpus: When true, the module-level logical axes are passed
                as sharding constraints; otherwise ``None`` is passed.

        Returns:
            Mean of the fused LayerNorm-MLP output.
        """
        # Each package interleaves (amax, scale) for its three entries.
        fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1],
                                       scale_list_1[1], amax_list_1[2], scale_list_1[2])
        fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1],
                                       scale_list_2[1], amax_list_2[2], scale_list_2[2])

        if multi_gpus:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
        else:
            layernorm_input_axes = None
            dot_1_input_axes = None
            dot_2_input_axes = None

        # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
        # NOTE(review): the third positional argument (None) is presumably the
        # LayerNorm bias, unused with rmsnorm -- confirm against the fused op.
        return jnp.mean(
            fused_layernorm_fp8_mlp(x,
                                    ln_scale,
                                    None, [kernel_1, kernel_2], [bias_1, bias_2],
                                    [fp8_meta_pkg1, fp8_meta_pkg2],
                                    layernorm_type,
                                    layernorm_input_axes=layernorm_input_axes,
                                    dot_1_input_axes=dot_1_input_axes,
                                    dot_2_input_axes=dot_2_input_axes,
                                    activation_type=activation_type,
                                    use_bias=use_bias))

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize('input_shape', INPUT_SHAPE)
    @pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')])
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape,
                                         dtype):
        """Compare the fused FP8 LayerNorm-MLP primitive on one device vs. a mesh.

        The forward value and the gradients w.r.t. every non-static input are
        computed once without sharding and once under the FSDP/TP mesh from
        ``mesh_config``; both results must agree within the tolerance that
        ``assert_allclose`` applies for ``dtype``.
        """
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = 'rmsnorm'

        # One (amax, scale) list per FP8 meta package, three entries each.
        fp8_amax_list_1 = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) for _ in range(3)
        ]
        fp8_amax_list_2 = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) for _ in range(3)
        ]
        fp8_scale_list_1 = [jnp.ones((1,), jnp.float32) for _ in range(3)]
        fp8_scale_list_2 = [jnp.ones((1,), jnp.float32) for _ in range(3)]

        base_inputs = self.generate_inputs(input_shape, activation_type, use_bias, dtype)
        # Only the kernels and the first bias get an explicit sharding below.
        _, _, k1, k2, b1, _ = base_inputs
        inputs = [
            *base_inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2
        ]
        static_inputs = [layernorm_type, activation_type, use_bias]

        num_inputs = len(inputs)
        value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func,
                                                 argnums=tuple(range(num_inputs)))

        # Single GPU
        single_jitter = jax.jit(value_and_grad_func,
                                static_argnums=tuple(
                                    range(num_inputs, num_inputs + len(static_inputs))))
        with fp8_autocast(enabled=True):
            single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

        # Multi GPUs
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
            k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp'))
            k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp'))
            k1_ = jax.device_put(k1, k1_sharding)
            k2_ = jax.device_put(k2, k2_sharding)
            if use_bias:
                b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp'))
                b1_ = jax.device_put(b1, b1_sharding)
            else:
                b1_sharding = b1_ = None
            multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]

            # Position ref for the sharding lists (one entry per non-static input):
            #   x, gamma, k1, k2, b1,
            #   b2, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2
            in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
                            None, None)
            out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None,
                                    None, None, None))

            num_multi_inputs = len(multi_inputs)
            multi_jitter = jax.jit(
                value_and_grad_func,
                in_shardings=in_shardings,
                out_shardings=out_shardings,
                static_argnums=tuple(
                    range(num_multi_inputs,
                          num_multi_inputs + len(static_inputs) + 1)))    # +1 for multi_gpus

            multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
        for i, (multi_grad, single_grad) in enumerate(zip(multi_grads, single_grads)):
            # No gradient is produced for inputs that were passed as None.
            if multi_grad is None:
                continue
            if isinstance(multi_grad, list):
                assert isinstance(single_grad, list)
                # Guard against zip() silently dropping unmatched entries.
                assert len(multi_grad) == len(single_grad)
                for m_grad, s_grad in zip(multi_grad, single_grad):
                    assert_allclose(m_grad,
                                    s_grad,
                                    dtype=dtype,
                                    err_msg=f'multi_grads[{i}] is not close')
            else:
                assert_allclose(multi_grad,
                                single_grad,
                                dtype=dtype,
                                err_msg=f'multi_grads[{i}] is not close')

    def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype,
                            use_fp8):
        """Check that a sharded ``LayerNormMLP`` matches the unsharded module.

        The module is initialised and applied once without a mesh and once
        inside the mesh described by ``mesh_config``, using the same RNG and
        input. Parameters, LayerNorm output and MLP output must then agree.

        Args:
            mesh_config: ``[device_count, mesh_shape, mesh_axes, mesh_resource]``.
            activation_type: Activations passed to ``LayerNormMLP``.
            use_bias: Whether the dense layers use biases.
            input_shape: ``[batch, seqlen, hidden_in]``.
            dtype: Dtype of the input and of the module.
            use_fp8: Whether ``fp8_autocast`` is enabled for both runs.
        """
        batch, seqlen, hidden_in = input_shape
        layernorm_type = 'rmsnorm'

        rng = jax.random.PRNGKey(0)
        subkeys = jax.random.split(rng, 2)

        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        # Shared by both inits so the parameters are comparable.
        init_rngs = {'params': subkeys[1]}

        # Single GPUs
        with fp8_autocast(enabled=use_fp8):
            ln_mlp_single = LayerNormMLP(
                layernorm_type=layernorm_type,
                transpose_batch_sequence=False,    # input: [batch, seqlen, hidden]
                intermediate_dim=INTERMEDIATE,
                activations=activation_type,
                dtype=dtype,
                use_bias=use_bias,
            )
            params_single = ln_mlp_single.init(init_rngs, x)
            mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single,
                                                                x,
                                                                deterministic=True)

        # Multi GPUs
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
            ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type,
                                          transpose_batch_sequence=False,
                                          intermediate_dim=INTERMEDIATE,
                                          activations=activation_type,
                                          dtype=dtype,
                                          scale_axes=(W_NO_SHARD_AXES,),
                                          ln_bias_axes=(W_NO_SHARD_AXES,),
                                          kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
                                          kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
                                          use_bias=use_bias,
                                          bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
                                          bias_axes_2=(W_NO_SHARD_AXES,),
                                          layernorm_input_axes=LAYERNORM_INPUT_AXES,
                                          dot_1_input_axes=DOT_1_INPUT_AXES,
                                          dot_2_input_axes=DOT_2_INPUT_AXES,
                                          name='mlp')
            params_sharded = ln_mlp_sharded.init(init_rngs, x)
            mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded,
                                                                   x,
                                                                   deterministic=True)

        # Make sure params values are the same
        assert_tree_like_allclose(params_sharded['params'], params_single['params'])
        assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)

    @pytest.mark.parametrize('input_shape', INPUT_SHAPE)
    @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')])
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('use_bias', [True, False])
    def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
        """Sharded vs. unsharded ``LayerNormMLP`` comparison with FP8 disabled."""
        self._test_layernorm_mlp(mesh_config,
                                 activation_type,
                                 use_bias,
                                 input_shape,
                                 dtype,
                                 use_fp8=False)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
    @pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')])
    @pytest.mark.parametrize('use_bias', [True, False])
    @pytest.mark.parametrize('input_shape', INPUT_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPES)
    def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape,
                                     dtype):
        """Sharded vs. unsharded ``LayerNormMLP`` comparison with FP8 enabled.

        Skipped when ``is_fp8_available()`` reports that FP8 is unsupported.
        """
        self._test_layernorm_mlp(mesh_config,
                                 activation_type,
                                 use_bias,
                                 input_shape,
                                 dtype,
                                 use_fp8=True)