# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
from contextlib import nullcontext
from dataclasses import dataclass
from functools import wraps

import jax
import jax.numpy as jnp
import pytest
from jax import lax
from jax import nn
from jax import value_and_grad, jit
from jax.typing import DTypeLike

from utils import assert_allclose

from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.flax.module import Softmax


def catch_unsupported(method):
    """
    Unsupported cases must raise an error rather than run incorrectly.
    This decorator checks that an unsupported case raises an AssertionError.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
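        # Expect an AssertionError for unsupported configs; supported ones
        # run under a no-op context.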
        if not self._is_support():
            assertion_checker = pytest.raises(AssertionError)
        else:
            assertion_checker = nullcontext()
        with assertion_checker:
            return method(self, *args, **kwargs)

    return wrapper


@dataclass
class SoftmaxRunner:
    """
    Softmax runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads: int
    scale_factor: float
    softmax_type: SoftmaxType
    dtype: DTypeLike

    @staticmethod
    def reference_softmax(logits, mask, scale_factor, **_):
        """
        JAX softmax as the reference
        """
        if mask is not None:
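            # Nonzero mask entries mark positions to drop: a large negative
            # bias drives them to ~0 probability after the softmax.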
            logits += lax.select(
                mask > 0,
                jnp.full(mask.shape, -1e10).astype(logits.dtype),
                jnp.full(mask.shape, 0.0).astype(logits.dtype),
            )
        return nn.softmax(logits * scale_factor)

    def _is_support(self):
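        # The fused softmax kernel supports only certain shape/dtype
        # combinations; unsupported configs are expected to assert
        # (see catch_unsupported).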
        return is_softmax_kernel_available(
            self.softmax_type,
            self.batch_size,
            self.num_heads,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.dtype,
        )

    def _setup_inputs(self):
        key = jax.random.PRNGKey(0)
        logits_key, mask_key = jax.random.split(key, 2)

        logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
        mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)

        self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)

        match self.softmax_type:
            case SoftmaxType.SCALED:
                self.mask = None
            case SoftmaxType.SCALED_MASKED:
                self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
            case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
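                # Causal mask: ones on the strict upper triangle mark future
                # positions to be masked out.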
                self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
            case _:
                raise ValueError(f"Unknown {self.softmax_type=}")

    def test_forward(self):
        """
        Test transformer_engine.jax.softmax.softmax fwd rule
        """
        self._setup_inputs()
        primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
        reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

    def test_backward(self):
        """
        Test transformer_engine.jax.softmax.softmax bwd rule
        """
        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
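            # value_and_grad needs a scalar output, so reduce the forward
            # result to its mean.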
            fwd_out = func(*args, **kwargs)
            return jnp.mean(fwd_out, dtype=jnp.float32).astype(self.dtype)

        args = [self.logits, self.mask]
        kwargs = {
            "scale_factor": self.scale_factor,
            "softmax_type": self.softmax_type,
        }

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,))
        )
        jitted_reference = jit(
            value_and_grad(
                lambda logits, *args: grad_func(
                    __class__.reference_softmax, logits, *args, **kwargs
                ),
                (0,),
            )
        )

        primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
        reference_out, (reference_grad_logits,) = jitted_reference(*args)

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)
        assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)


class SoftmaxPrimitivesRunner(SoftmaxRunner):
    """
    JAX Softmax primitives runner
    """

    @catch_unsupported
    def test_forward(self):
        return super().test_forward()

    @catch_unsupported
    def test_backward(self):
        return super().test_backward()


class SoftmaxModuleRunner:
    """
    JAX Softmax module runner
    """

    module_runner: SoftmaxRunner
    bias: None

    def __init__(self, module_runner, bias):
        self.module_runner = module_runner
        self.bias = bias
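        # The bias is stored but not exercised by the forward test below.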

    def test_forward(self):
        """
        Test transformer_engine.jax.flax.module.Softmax fwd rule
        """
        runner = self.module_runner
        runner._setup_inputs()
        rng = jax.random.PRNGKey(0)
        softmax_module = Softmax(
            scale_factor=runner.scale_factor,
            softmax_type=runner.softmax_type,
        )
        softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
        module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
        reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
        assert_allclose(module_out, reference_out, dtype=runner.dtype)


# Run softmax primitives test
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
    ],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(SoftmaxType.SCALED, id="SCALED"),
        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
class TestSoftmaxPrimitives:
    """
    Test transformer_engine.jax.softmax.softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test forward with parameterized configs
        """
        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test backward with parameterized configs
        """
        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        runner.test_backward()


# Run Softmax module test
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
        # triggers backup framework implementation due to (s_q % 4) != 0
        pytest.param(8, 511, 512, 16, id="8-511-512-16"),
    ],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(SoftmaxType.SCALED, id="SCALED"),
        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
class TestSoftmaxModule:
    """
    Test transformer_engine.jax.flax.module.Softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test forward with parameterized configs
        """
        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        bias = None
        runner = SoftmaxModuleRunner(module_runner, bias)
        runner.test_forward()
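

# A minimal sketch for running one configuration by hand, assuming a build of
# Transformer Engine whose fused softmax kernel supports this shape/dtype
# (hypothetical standalone usage; pytest is the intended driver):
if __name__ == "__main__":
    _runner = SoftmaxRunner(
        batch_size=8,
        max_seqlen_q=16,
        max_seqlen_kv=16,
        num_heads=16,
        scale_factor=0.125,
        softmax_type=SoftmaxType.SCALED,
        dtype=jnp.float16,
    )
    _runner.test_forward()
    _runner.test_backward()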