# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import warnings
6
7
8
9
10
11
12
13
14
15
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
16
17
from utils import pytest_parametrize_wrapper

18
from transformer_engine.jax import fp8_autocast
19
from transformer_engine.common import recipe
20
from transformer_engine.jax.layernorm import layernorm
21
22
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available

23
24
25

# Input dtypes every norm test below is parametrized over.
DTYPES = [jnp.bfloat16, jnp.float32]

# Activation shapes fed to pytest_parametrize_wrapper("data_shape", ...).
# NOTE(review): the "L0"/"L2" keys presumably select a test level -- confirm
# against pytest_parametrize_wrapper in utils.
NORM_INPUT_SHAPES = {
    "L0": [[64, 64]],
    "L2": [[64, 64]],
}

# NOTE: `reason` is rebound by the second call, so after these two lines it
# holds only the value returned for the MXFP8 query.
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# FP8 recipes to parametrize over; a recipe is listed only when the matching
# availability query above reported support, so the list may be empty.
SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

40
41
42

class TestDistributedLayernorm:
    """Sharded LayerNorm / RMSNorm tests executed under ``fp8_autocast``.

    Each test builds inputs with ``generate_inputs``, places them on a device
    mesh, and hands the Transformer Engine ``layernorm`` call plus a pure-JAX
    reference to ``compare_ops`` together with the expected collective counts
    from ``generate_collectives_count_ref``.
    """

    def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
        """Create ``x``, ``gamma``, ``beta`` and the partition spec for each.

        Args:
            shape: Shape of the activation ``x``; must have rank 2 or 3.
            mesh_resource: Object whose ``dp_resource`` attribute is used as the
                partition entry for the leading dimension of ``x``.
            dtype: Element type used for ``x``, ``gamma`` and ``beta``.
            shard_weights: If true, ``gamma``/``beta`` are partitioned along
                ``dp_resource`` as well; otherwise their spec is ``(None,)``.

        Returns:
            ``((x, gamma, beta), (x_pspec, g_pspec, b_pspec))``.

        Raises:
            NotImplementedError: If ``shape`` does not have rank 2 or 3.
        """
        rank = len(shape)
        # Reject unsupported ranks before allocating any array.
        if rank not in (2, 3):
            raise NotImplementedError(f"Unsupported input rank {rank}; expected 2 or 3")

        weight_shape = (shape[-1],)

        # Fixed PRNG key so every parametrization sees the same activations.
        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        gamma = jnp.ones(weight_shape, dtype=dtype)
        beta = jnp.ones(weight_shape, dtype=dtype)

        # Only the leading dimension of x is partitioned; the rest get None.
        x_pspec = PartitionSpec(mesh_resource.dp_resource, *([None] * (rank - 1)))
        g_pspec = b_pspec = (
            PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
        )

        return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

    def generate_collectives_count_ref(
        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
    ):
        """Build the expected collective counts passed to ``compare_ops``.

        Args:
            mesh_resource: Object whose ``dp_resource`` attribute tells whether
                data parallelism is enabled (``None`` means disabled).
            ln_type: ``"layernorm"`` or ``"rmsnorm"``.
            shape: Activation shape; only ``shape[-1]`` is used.
            dtype: Input dtype; its canonicalized item size scales the weight bytes.
            mesh_axes: Mesh axis names; checked for the presence of ``"dp"``.
            fp8_recipe: Active FP8 recipe, compared against ``MXFP8BlockScaling()``.

        Returns:
            Whatever ``generate_collectives_count`` returns for the computed
            ``allreduce``/``allgather``/``other`` byte counts.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        is_dp_enabled = mesh_resource.dp_resource is not None
        assert ln_type in ["layernorm", "rmsnorm"]
        all_reduce_loss_bytes = 4  # 1 * FP32
        # for loss, dgamma and dbeta
        # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
        weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
        allreduce_total_bytes = (
            all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
        )
        other_bytes = 0
        if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
            other_bytes = 384  # required for small scale shapes that require padding
        # The all-reduce expectation is zeroed out when no dp resource is configured.
        return generate_collectives_count(
            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    def test_layernorm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        zero_centered_gamma,
        shard_weights,
        fp8_recipe,
    ):
        """Compare the TE ``layernorm`` primitive with a pure-JAX LayerNorm reference."""
        epsilon = 1e-6
        ln_type = "layernorm"
        q_dtype = jnp.float8_e4m3fn

        def target_func(x, gamma, beta):
            # Scalar mean so that compare_ops can differentiate w.r.t. the inputs.
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(
                layernorm(
                    x, gamma, beta, ln_type, zero_centered_gamma, epsilon, quantizer=quantizer
                )
            )

        def ref_func(x, gamma, beta):
            # Statistics are computed in float32, then the result is cast back to x.dtype.
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
            if zero_centered_gamma:
                output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
            else:
                output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
            return jnp.mean(output)

        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))

            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_, beta_],
                        collective_count_ref,
                        grad_args=(0, 1, 2),
                        metric_fwd_dtype=q_dtype,
                        metric_bwd_dtype=q_dtype,
                        in_shardings=(x_pspec, g_pspec, b_pspec),
                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
                    )
                except AssertionError as err:
                    # Layernorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma and/or beta. We can catch
                    # and ignore that specific error here.
                    if (
                        g_pspec[-1] is None and b_pspec[-1] is None
                    ) or "Expected collective count" not in str(err):
                        raise
                finally:
                    # Runs whether or not compare_ops raised: every recorded warning
                    # must be the "no sharding of hidden dim" one.
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "Layernorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    def test_rmsnorm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        dtype,
        shard_weights,
        fp8_recipe,
    ):
        """Compare the TE ``layernorm`` primitive in RMSNorm mode with a pure-JAX reference."""
        epsilon = 1e-6
        ln_type = "rmsnorm"
        q_dtype = jnp.float8_e4m3fn

        def target_func(x, gamma):
            # RMSNorm takes no beta, and zero_centered_gamma is passed as False.
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))

        def ref_func(x, gamma):
            # Normalize in float32, cast back to the test dtype, then apply gamma.
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype)
            output = y * gamma
            return jnp.mean(output)

        (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, gamma_],
                        collective_count_ref,
                        grad_args=(0, 1),
                        metric_fwd_dtype=q_dtype,
                        metric_bwd_dtype=q_dtype,
                        in_shardings=(x_pspec, g_pspec),
                        out_shardings=(None, (x_pspec, g_pspec)),
                    )
                except AssertionError as err:
                    # RmsNorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may not be the same
                    # when XLA is forced to unshard gamma. We can catch
                    # and ignore that specific error here.
                    if g_pspec[-1] is None or "Expected collective count" not in str(err):
                        raise
                finally:
                    # Runs whether or not compare_ops raised: every recorded warning
                    # must be the "no sharding of hidden dim" one.
                    for w in warns:
                        assert "Enforcing no sharding of parameters hidden dim!" in str(w), (
                            "RmsNorm primitive did not raise the correct warning for "
                            "unsupported sharding of gamma and/or beta"
                        )