# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
from contextlib import nullcontext
from dataclasses import dataclass
from functools import wraps

import jax
import jax.numpy as jnp
import pytest
from jax import lax
from jax import nn
from jax import value_and_grad, jit
from jax.typing import DTypeLike

from utils import assert_allclose

from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
from transformer_engine.jax.flax.module import Softmax


def catch_unsupported(method):
    """
    Unsupported cases should raise an error rather than run incorrectly.
    This decorator checks that an unsupported case raises an AssertionError.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self._is_support():
            assertion_checker = pytest.raises(AssertionError)
        else:
            assertion_checker = nullcontext()
        with assertion_checker:
            return method(self, *args, **kwargs)

    return wrapper


@dataclass
class SoftmaxRunner:
    """
    Softmax runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads: int
    scale_factor: float
    softmax_fusion_type: SoftmaxFusionType
    dtype: DTypeLike
    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX

    @staticmethod
    def reference_softmax(logits, mask, scale_factor, **_):
        """
        Jax softmax as the reference
        """
        if mask is not None:
            logits += lax.select(
                mask > 0,
                jnp.full(mask.shape, -1e10).astype(logits.dtype),
                jnp.full(mask.shape, 0.0).astype(logits.dtype),
            )
        return nn.softmax(logits * scale_factor)

    def _is_support(self):
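        """Check whether the fused softmax kernel is available for this configuration."""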
        return is_softmax_kernel_available(
            self.softmax_fusion_type,
            self.softmax_type,
            self.batch_size,
            self.num_heads,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.dtype,
        )

    def _setup_inputs(self):
        key = jax.random.PRNGKey(0)
        logits_key, mask_key = jax.random.split(key, 2)

        logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
        mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)

        self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)

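        # Mask semantics: 1 = masked-out position, 0 = keep; the SCALED case uses no mask at all.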
        match self.softmax_fusion_type:
            case SoftmaxFusionType.SCALED:
                self.mask = None
            case SoftmaxFusionType.SCALED_MASKED:
                self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
            case SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
                self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
            case _:
                raise ValueError(f"Unknown {self.softmax_fusion_type=}")

    def test_forward(self):
        """
        Test transformer_engine.jax.softmax.softmax fwd rule
        """
        self._setup_inputs()
        primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_fusion_type)
        reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

    def test_backward(self):
        """
        Test transformer_engine.jax.softmax.softmax bwd rule
        """
        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
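            """Reduce the forward output to a scalar so value_and_grad can differentiate w.r.t. the logits."""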
            fwd_out = func(*args, **kwargs)
            return jnp.mean(fwd_out, dtype=jnp.float32).astype(self.dtype)

        args = [self.logits, self.mask]
        kwargs = {
            "scale_factor": self.scale_factor,
            "softmax_fusion_type": self.softmax_fusion_type,
        }

        # Summing in FP16/BF16 may overflow, so use FP32 for the reduction
        jitted_primitive = jit(
            value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,))
        )
        jitted_reference = jit(
            value_and_grad(
                lambda logits, *args: grad_func(
                    __class__.reference_softmax, logits, *args, **kwargs
                ),
                (0,),
            )
        )

        primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
        reference_out, (reference_grad_logits,) = jitted_reference(*args)

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)
        assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)


class SoftmaxPrimitivesRunner(SoftmaxRunner):
    """
    Jax Softmax Primitives runner
    """

    @catch_unsupported
    def test_forward(self):
        return super().test_forward()

    @catch_unsupported
    def test_backward(self):
        return super().test_backward()


class SoftmaxModuleRunner:
    """
    Jax Softmax Module runner
    """

    module_runner: SoftmaxRunner
    bias: None

    def __init__(self, module_runner, bias):
        self.module_runner = module_runner
        self.bias = bias

    def test_forward(self):
        """
        Test transformer_engine.jax.flax.module.Softmax fwd rule
        """
        runner = self.module_runner
        runner._setup_inputs()
        rng = jax.random.PRNGKey(0)
        softmax_module = Softmax(
            scale_factor=runner.scale_factor,
            softmax_fusion_type=runner.softmax_fusion_type,
        )
        softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
        module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
        reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
        assert_allclose(module_out, reference_out, dtype=runner.dtype)


# Run softmax primitives test
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
    ],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
    "softmax_fusion_type",
    [
        pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
        pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
        pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
class TestSoftmaxPrimitives:
    """
    Test transformer_engine.jax.softmax.softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
        """
        Test forward with parameterized configs
        """
        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
        """
        Test backward with parameterized configs
        """
        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
        runner.test_backward()


# Run Softmax module test
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
        # triggers backup framework implementation due to (s_q % 4) != 0
        pytest.param(8, 511, 512, 16, id="8-511-512-16"),
    ],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
    "softmax_fusion_type",
    [
        pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
        pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
        pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
class TestSoftmaxModule:
    """
    Test transformer_engine.jax.flax.module.Softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
        """
        Test forward with parameterized configs
        """
        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
        bias = None
        runner = SoftmaxModuleRunner(module_runner, bias)
        runner.test_forward()