# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import warnings
import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.layernorm import layernorm

# Input/parameter dtypes exercised by every parametrized test in this module.
DTYPES = [
    jnp.bfloat16,
    jnp.float32,
]


class TestDistributedLayernorm:
    """Sharded-execution tests for the Transformer Engine JAX ``layernorm`` primitive.

    Each test builds a device mesh from a ``generate_configs()`` entry and places the
    input on it, sharded along ``mesh_resource.dp_resource``. It optionally shards
    gamma/beta as well. It then hands the primitive, a pure-JAX reference and the
    expected collective counts to ``compare_ops``.
    """

    def generate_inputs(self, shape, mesh_resource, dtype, shard_weights):
        """Create norm inputs and the partition specs used to place them on the mesh.

        Args:
            shape: Input shape. Only rank-2 and rank-3 shapes are supported.
            mesh_resource: Mesh resource description. Only ``dp_resource`` is read.
            dtype: Data type of every generated array.
            shard_weights: If True, gamma and beta are partitioned along
                ``dp_resource``. Otherwise they are replicated.

        Returns:
            ``((x, gamma, beta), (x_pspec, g_pspec, b_pspec))``. ``x`` is random-normal
            data drawn from a fixed seed. ``gamma`` and ``beta`` are all-ones vectors
            of length ``shape[-1]``. ``x_pspec`` partitions only the leading axis of ``x``.

        Raises:
            NotImplementedError: If ``shape`` is not rank 2 or 3.
        """
        if len(shape) not in (2, 3):
            raise NotImplementedError(
                f"Unsupported input rank {len(shape)}; only rank 2 and 3 are supported")

        weight_shape = (shape[-1],)

        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        gamma = jnp.ones(weight_shape, dtype=dtype)
        beta = jnp.ones(weight_shape, dtype=dtype)

        # Only the leading axis is split over data parallelism. Every other axis,
        # including the last (normalized) one, stays replicated.
        x_pspec = PartitionSpec(mesh_resource.dp_resource, *([None] * (len(shape) - 1)))

        if shard_weights:
            g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource)
        else:
            g_pspec = b_pspec = PartitionSpec(None)

        return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

    def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
        """Return the expected collective counts for ``compare_ops``.

        When a ``dp_resource`` is configured, the expected all-reduce traffic in bytes
        is the FP32 loss plus one ``shape[-1]``-sized gradient per weight: gamma and
        beta for ``'layernorm'``, gamma only for ``'rmsnorm'``. No all-gather or other
        collectives are expected. Without a ``dp_resource`` every count is zero.
        """
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        is_dp_enabled = mesh_resource.dp_resource is not None
        assert ln_type in ['layernorm', 'rmsnorm']
        all_reduce_loss_bytes = 4    # 1 * FP32
        # Gradients all-reduced besides the loss: dgamma (+ dbeta for layernorm).
        weight_count = 2 if ln_type == 'layernorm' else 1
        allreduce_total_bytes = (all_reduce_loss_bytes
                                 + weight_count * shape[-1] * jax_dtype.itemsize)
        return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
                                          allgather=0,
                                          other=0)

    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
    @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('zero_centered_gamma', [False, True])
    @pytest.mark.parametrize('shard_weights', [False, True])
    def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
                       zero_centered_gamma, shard_weights):
        """Compare sharded ``layernorm`` against an FP32 pure-JAX reference.

        ``compare_ops`` receives ``grad_args=(0, 1, 2)``, so x, gamma and beta are
        presumably all differentiated. When gamma/beta are sharded, only a
        collective-count mismatch is tolerated. Every warning recorded during the
        comparison must be the expected "no sharding of hidden dim" warning.
        """
        epsilon = 1e-6
        ln_type = 'layernorm'

        def target_func(x, gamma, beta):
            return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))

        def ref_func(x, gamma, beta):
            x_ = jnp.asarray(x, jnp.float32)
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon)
            if zero_centered_gamma:
                output = jnp.asarray(normed_input * (gamma + 1) + beta).astype(x.dtype)
            else:
                output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
            return jnp.mean(output)

        (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
            self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
        collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
                                                                   data_shape, dtype)
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))

            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(target_func,
                                ref_func, [x_, gamma_, beta_],
                                collective_count_ref,
                                grad_args=(0, 1, 2),
                                metric_fwd_dtype=dtype,
                                metric_bwd_dtype=dtype,
                                in_shardings=(x_pspec, g_pspec, b_pspec),
                                out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
                except AssertionError as err:
                    # Layernorm should still produce the correct numerical result with
                    # gamma/beta sharded. However, the collective count may differ when
                    # XLA is forced to unshard gamma and/or beta, so that specific
                    # error is ignored here.
                    weights_replicated = g_pspec[-1] is None and b_pspec[-1] is None
                    if weights_replicated or "Expected collective count" not in str(err):
                        raise

            # Checked after the try (not in a ``finally``). A genuine compare_ops failure
            # therefore propagates as-is instead of being masked by this assertion.
            for w in warns:
                assert "Enforcing no sharding of parameters hidden dim!" in str(w.message), (
                    "Layernorm primitive did not raise the correct warning for "
                    "unsupported sharding of gamma and/or beta"
                )

    @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
    @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
    @pytest.mark.parametrize('dtype', DTYPES)
    @pytest.mark.parametrize('shard_weights', [False, True])
    def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
                     shard_weights):
        """Compare sharded ``rmsnorm`` (via ``layernorm``) against a pure-JAX reference.

        ``compare_ops`` receives ``grad_args=(0, 1)``, so x and gamma are presumably
        differentiated. When gamma is sharded, only a collective-count mismatch is
        tolerated. Every warning recorded during the comparison must be the expected
        "no sharding of hidden dim" warning.
        """
        epsilon = 1e-6
        ln_type = 'rmsnorm'

        def target_func(x, gamma):
            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))

        def ref_func(x, gamma):
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), dtype)
            output = y * gamma
            return jnp.mean(output)

        # RMSNorm has no beta; its generated array and spec are discarded.
        (x, gamma, _), (x_pspec, g_pspec, _) = \
            self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
        collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
                                                                   data_shape, dtype)
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))

            with warnings.catch_warnings(record=True) as warns:
                try:
                    compare_ops(target_func,
                                ref_func, [x_, gamma_],
                                collective_count_ref,
                                grad_args=(0, 1),
                                metric_fwd_dtype=dtype,
                                metric_bwd_dtype=dtype,
                                in_shardings=(x_pspec, g_pspec),
                                out_shardings=(None, (x_pspec, g_pspec)))
                except AssertionError as err:
                    # RmsNorm should still produce the correct numerical result with
                    # gamma sharded. However, the collective count may differ when XLA
                    # is forced to unshard gamma, so that specific error is ignored here.
                    if g_pspec[-1] is None or "Expected collective count" not in str(err):
                        raise

            # Checked after the try (not in a ``finally``). A genuine compare_ops failure
            # therefore propagates as-is instead of being masked by this assertion.
            for w in warns:
                assert "Enforcing no sharding of parameters hidden dim!" in str(w.message), (
                    "RmsNorm primitive did not raise the correct warning for "
                    "unsupported sharding of gamma and/or beta"
                )