# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import warnings
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper

from transformer_engine.jax import fp8_autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available


DTYPES = [jnp.bfloat16, jnp.float32]

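# Norm input shapes, grouped under the "L0"/"L1"/"L2" test-level keys consumed by pytest_parametrize_wrapper.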
NORM_INPUT_SHAPES = {
    "L0": [[64, 64]],
    "L1": [[64, 64]],
    "L2": [[64, 64]],
}

is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

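# Only parametrize over FP8 recipes that the available devices support.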
SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))


class TestDistributedLayernorm:

    def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
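        """Create norm inputs and the matching PartitionSpecs.

        The activation is sharded along the data-parallel axis on its leading dim;
        gamma/beta are replicated, or sharded along dp when shard_weights is set
        (an unsupported layout used to exercise the sharding-fallback warning).
        """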
        weight_shape = (shape[-1],)

        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        gamma = jnp.ones(weight_shape, dtype=dtype)
        beta = jnp.ones(weight_shape, dtype=dtype)

        if len(shape) == 2:
            x_pspec = PartitionSpec(mesh_resource.dp_resource, None)
        elif len(shape) == 3:
            x_pspec = PartitionSpec(mesh_resource.dp_resource, None, None)
        else:
            raise NotImplementedError

        g_pspec = b_pspec = (
            PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
        )

        return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

    def generate_collectives_count_ref(
        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
    ):
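        """Expected collective byte counts for compare_ops.

        With data parallelism enabled, the scalar loss and the dgamma (and, for
        layernorm, dbeta) gradients are all-reduced; Float8CurrentScaling adds one
        amax reduction. No all-gathers are expected.
        """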
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        is_dp_enabled = mesh_resource.dp_resource is not None
        assert ln_type in ["layernorm", "rmsnorm"]
        all_reduce_loss_bytes = 4  # 1 * FP32
        # for loss, dgamma and dbeta
        # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
        weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
        allreduce_total_bytes = (
            all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
        )
        other_bytes = 0
        if fp8_recipe == recipe.Float8CurrentScaling():
            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
        return generate_collectives_count(
            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("use_shardy", [False, True])
    def test_layernorm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        zero_centered_gamma,
        shard_weights,
        fp8_recipe,
        use_shardy,
    ):
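        """Compare distributed TE layernorm (FP8-quantized output) with a pure-JAX
        reference, checking forward/backward numerics and collective counts."""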
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        epsilon = 1e-6
        ln_type = "layernorm"
        q_dtype = jnp.float8_e4m3fn

        def target_func(x, gamma, beta):
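            # TE fused layernorm with the active recipe's quantizer applied to the
            # output; reduced to a scalar so gradients can be taken.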
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(
                layernorm(
                    x, gamma, beta, ln_type, zero_centered_gamma, epsilon, quantizer=quantizer
                )
            )

        def ref_func(x, gamma, beta):
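            # Pure-JAX layernorm reference computed in FP32, cast back to the input dtype.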
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
            if zero_centered_gamma:
                output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
            else:
                output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
            return jnp.mean(output)

        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))

            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_, beta_],
                        collective_count_ref,
                        grad_args=(0, 1, 2),
                        metric_fwd_dtype=q_dtype,
                        metric_bwd_dtype=q_dtype,
                        in_shardings=(x_pspec, g_pspec, b_pspec),
                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
                    )
                except AssertionError as err:
                    # Layernorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma and/or beta. We can catch
                    # and ignore that specific error here.
                    if (
                        g_pspec[-1] is None and b_pspec[-1] is None
                    ) or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "Layernorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("use_shardy", [False, True])
    def test_rmsnorm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        shard_weights,
        fp8_recipe,
        use_shardy,
    ):
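        """Compare distributed TE rmsnorm (FP8-quantized output) with a pure-JAX
        reference, checking forward/backward numerics and collective counts."""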
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        epsilon = 1e-6
        ln_type = "rmsnorm"
        q_dtype = jnp.float8_e4m3fn

        def target_func(x, gamma):
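            # TE fused rmsnorm (beta=None, zero_centered_gamma=False) with the
            # active recipe's quantizer applied to the output.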
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))

        def ref_func(x, gamma):
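            # Pure-JAX RMSNorm reference computed in FP32, cast back to the input dtype.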
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype)
            output = y * gamma
            return jnp.mean(output)

        (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_],
                        collective_count_ref,
                        grad_args=(0, 1),
                        metric_fwd_dtype=q_dtype,
                        metric_bwd_dtype=q_dtype,
                        in_shardings=(x_pspec, g_pspec),
                        out_shardings=(None, (x_pspec, g_pspec)),
                    )
                except AssertionError as err:
                    # RmsNorm should still produce the correct numerical result with
                    # gamma sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma. We can catch and ignore
                    # that specific error here.
                    if g_pspec[-1] is None or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "RmsNorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )