# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import warnings

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper

from transformer_engine.jax import fp8_autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available

DTYPES = [jnp.bfloat16, jnp.float32]

26
27
28
29
30
31
NORM_INPUT_SHAPES = {
    "L0": [[64, 64]],
    "L2": [[64, 64]],
}

is_fp8_supported, reason = is_fp8_available()
32
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
33
34
35
36

SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
37
    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
38
39
40
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

41
42
43

class TestDistributedLayernorm:

44
    def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
45
46
47
48
49
50
51
52
53
54
55
56
57
        weight_shape = (shape[-1],)

        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        gamma = jnp.ones(weight_shape, dtype=dtype)
        beta = jnp.ones(weight_shape, dtype=dtype)

        if len(shape) == 2:
            x_pspec = PartitionSpec(mesh_resource.dp_resource, None)
        elif len(shape) == 3:
            x_pspec = PartitionSpec(mesh_resource.dp_resource, None, None)
        else:
            raise NotImplementedError

58
59
60
        g_pspec = b_pspec = (
            PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
        )
61
62
63

        return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

64
65
66
    def generate_collectives_count_ref(
        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
    ):
67
68
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        is_dp_enabled = mesh_resource.dp_resource is not None
69
70
        assert ln_type in ["layernorm", "rmsnorm"]
        all_reduce_loss_bytes = 4  # 1 * FP32
71
        # for loss, dgamma and dbeta
72
73
        # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
        weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
74
75
76
        allreduce_total_bytes = (
            all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
        )
77
78
79
        other_bytes = 0
        if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
            other_bytes = 384  # required for small scale shapes that require padding
80
        if fp8_recipe == recipe.Float8CurrentScaling():
81
            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
82
        return generate_collectives_count(
83
            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
84
85
86
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
87
88
89
90
91
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
92
    @pytest_parametrize_wrapper("use_shardy", [False, True])
93
94
95
96
97
98
99
100
101
102
    def test_layernorm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        zero_centered_gamma,
        shard_weights,
103
        fp8_recipe,
104
        use_shardy,
105
    ):
106
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
107
        epsilon = 1e-6
108
        ln_type = "layernorm"
109
        q_dtype = jnp.float8_e4m3fn
110
111

        def target_func(x, gamma, beta):
112
113
114
115
116
117
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(
                layernorm(
                    x, gamma, beta, ln_type, zero_centered_gamma, epsilon, quantizer=quantizer
                )
            )
118
119
120
121
122
123
124
125
126
127
128
129

        def ref_func(x, gamma, beta):
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
            if zero_centered_gamma:
                output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
            else:
                output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
            return jnp.mean(output)

130
131
132
133
        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
134
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
135
        )
136
137
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
138
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
139
140
141
142
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))

143
144
            with warnings.catch_warnings(record=True) as warns:
                try:
145
146
147
148
149
150
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_, beta_],
                        collective_count_ref,
                        grad_args=(0, 1, 2),
151
152
                        metric_fwd_dtype=q_dtype,
                        metric_bwd_dtype=q_dtype,
153
154
155
                        in_shardings=(x_pspec, g_pspec, b_pspec),
                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
                    )
156
157
158
159
160
                except AssertionError as err:
                    # Layernorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma and/or beta. We can catch
                    # and ignore that specific error here.
161
162
163
                    if (
                        g_pspec[-1] is None and b_pspec[-1] is None
                    ) or "Expected collective count" not in str(err):
164
165
166
167
168
169
170
                        raise err
                finally:
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "Layernorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )
171

172
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
173
174
175
176
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
177
    @pytest_parametrize_wrapper("use_shardy", [False, True])
178
    def test_rmsnorm(
179
180
181
182
183
184
185
186
187
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        shard_weights,
        fp8_recipe,
188
        use_shardy,
189
    ):
190
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
191
        epsilon = 1e-6
192
        ln_type = "rmsnorm"
193
        q_dtype = jnp.float8_e4m3fn
194
195

        def target_func(x, gamma):
196
197
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))
198
199
200
201
202
203
204
205

        def ref_func(x, gamma):
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype)
            output = y * gamma
            return jnp.mean(output)

206
207
208
209
        (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
210
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
211
        )
212
213
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
214
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
215
216
217
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

218
219
            with warnings.catch_warnings(record=True) as warns:
                try:
220
221
222
223
224
225
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_],
                        collective_count_ref,
                        grad_args=(0, 1),
226
227
                        metric_fwd_dtype=q_dtype,
                        metric_bwd_dtype=q_dtype,
228
229
230
                        in_shardings=(x_pspec, g_pspec),
                        out_shardings=(None, (x_pspec, g_pspec)),
                    )
231
232
233
234
235
236
237
238
239
240
241
242
243
                except AssertionError as err:
                    # RmsNorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma. We can catch
                    # and ignore that specific error here.
                    if g_pspec[-1] is None or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "RmsNorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )