# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

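"""Distributed tests for the fused softmax in transformer_engine.jax.

Each case shards a [batch, heads, q_seqlen, k_seqlen] logits tensor across a
device mesh and compares the fused softmax against a pure-JAX reference,
including the expected communication collectives.
"""
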
import warnings
from functools import partial

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax

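# Activation dtypes exercised by the fused softmax kernels.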
DTYPES = [jnp.float16, jnp.bfloat16]


class TestDistributedSoftmax:

    def generate_collectives_count_ref(self):
        # The only collective expected from this test is the all-reduce of the
        # scalar FP32 loss produced by jnp.mean in the target/ref functions.
        all_reduce_loss_bytes = 4  # 1 scalar * 4 bytes (FP32)
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(
        self, shape, mesh_resource, softmax_fusion_type, dtype, bad_sharding, broadcast_batch_mask
    ):
        batch, _, seqlen, _ = shape

        x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        if softmax_fusion_type == SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
            mask = make_causal_mask(batch, seqlen)
        else:
            mask = make_self_mask(1 if broadcast_batch_mask else batch, seqlen)

        # The fused softmax supports sharding the batch and heads dimensions;
        # "bad" sharding instead puts the tensor-parallel resource on the last
        # (hidden) dimension, which the primitive warns about and unshards.
        if not bad_sharding:
            x_pspec = PartitionSpec(
                mesh_resource.dp_resource, mesh_resource.tpsp_resource, None, None
            )
        else:
            x_pspec = PartitionSpec(
                mesh_resource.dp_resource, None, None, mesh_resource.tpsp_resource
            )

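        # A mask broadcast over batch (batch dim of 1) stays fully replicated;
        # otherwise the mask is sharded along batch to match the input.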
        if broadcast_batch_mask:
            mask_pspec = PartitionSpec(None, None, None, None)
        else:
            mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)

        return (x, mask), (x_pspec, mask_pspec)

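    # Target: TE's fused softmax, reduced to a scalar loss so compare_ops can
    # also check gradients.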
    @staticmethod
    def target_func(x, mask, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED):
        return jnp.mean(
            softmax(x, mask, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type)
        )

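    # Reference: pure jax.nn.softmax, with the mask applied as an additive
    # -1e10 bias before scaling.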
    @staticmethod
    def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
        bias = None
        if mask is not None:
            bias = jax.lax.select(
                mask > 0,
                jnp.full(mask.shape, -1e10).astype(dtype),
                jnp.full(mask.shape, 0.0).astype(dtype),
            )
        if bias is not None:
            x = x + bias.astype(dtype)
        output = jax.nn.softmax(x * scale_factor)
        return jnp.mean(output)

    def impl_test_softmax(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        softmax_fusion_type,
        scale_factor,
        dtype,
        bad_sharding,
        broadcast_batch_mask,
        use_shardy,
    ):
        if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED:
            pytest.skip("Broadcasting the mask over batch only applies to SCALED_MASKED softmax.")

        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        target_func = partial(
            self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type
        )
        ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)

        (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
            data_shape,
            mesh_resource,
            softmax_fusion_type,
            dtype,
            bad_sharding,
            broadcast_batch_mask,
        )
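        # Build the device mesh and place the sharded operands on it before
        # comparing the fused softmax against the reference.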
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, autocast(mesh_resource=mesh_resource):
            x_named_sharding = NamedSharding(mesh, x_pspec)
            mask_named_sharding = NamedSharding(mesh, mask_pspec)
            x_ = jax.device_put(x, x_named_sharding)
            mask_ = jax.device_put(mask, mask_named_sharding)

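            # Record warnings so the finally-block below can assert that the
            # hidden-dimension sharding warning fires for bad_sharding runs.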
            with warnings.catch_warnings(record=True) as warns:
                try:
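                    # compare_ops jit-compiles both functions under the given
                    # shardings and checks the forward/backward outputs against
                    # the reference as well as the collective counts.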
                    compare_ops(
                        target_func,
                        ref_func,
                        [x_, mask_],
                        collective_count_ref,
                        grad_args=(0,),
                        metric_fwd_dtype=dtype,
                        metric_bwd_dtype=dtype,
                        in_shardings=(x_named_sharding, mask_named_sharding),
                        out_shardings=(None, x_named_sharding),
                    )
                except AssertionError as err:
                    # Softmax should still produce the correct numerical result with
                    # bad sharding. However, the collective count may not be the same
                    # when XLA is forced to unshard the hidden dimension. We can catch
                    # and ignore that specific error here.
                    if not bad_sharding or "Expected collective count" not in str(err):
                        raise err
                finally:
                    for w in warns:
                        assert "Sharding the hidden dimension is not supported" in str(w), (
                            "Softmax primitive did not raise the correct warning for "
                            "unsupported sharding in the hidden dimension. Got: "
                            f"{str(w)}"
                        )

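    # Full sweep (mesh configs, shapes, fusion types, scale factors, dtypes)
    # under the Shardy partitioner.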
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
    @pytest.mark.parametrize(
        "softmax_fusion_type",
        [
            SoftmaxFusionType.SCALED,
            SoftmaxFusionType.SCALED_MASKED,
            SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
        ],
    )
    @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("bad_sharding", [False, True])
    @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
    def test_softmax(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        softmax_fusion_type,
        scale_factor,
        dtype,
        bad_sharding,
        broadcast_batch_mask,
    ):
        self.impl_test_softmax(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape,
            softmax_fusion_type,
            scale_factor,
            dtype,
            bad_sharding,
            broadcast_batch_mask,
            use_shardy=True,
        )

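    # Reduced sweep under the legacy GSPMD partitioner
    # (jax_use_shardy_partitioner=False).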
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize(
        "softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED]
    )
    @pytest.mark.parametrize("bad_sharding", [False, True])
    @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
    def test_softmax_gspmd(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        softmax_fusion_type,
        bad_sharding,
        broadcast_batch_mask,
    ):
        self.impl_test_softmax(
            device_count,
            mesh_shape,
            mesh_axes,
            mesh_resource,
            data_shape=[32, 12, 128, 128],
            softmax_fusion_type=softmax_fusion_type,
            scale_factor=1.0,
            dtype=DTYPES[0],
            bad_sharding=bad_sharding,
            broadcast_batch_mask=broadcast_batch_mask,
            use_shardy=False,
        )