# test_fused_softmax.py
"""Test for fused softmax functions.

Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
"""  # NOQA
import itertools
import unittest

import torch

from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax


def attention_mask_func(attention_scores, attention_mask):
    """Mask attention scores for the softmax modules under test.

    Positions where ``attention_mask`` is True are replaced with -10000.0, a
    large negative value that drives their softmax probability toward zero.
    ``attention_mask`` is expected to be a bool tensor broadcastable to
    ``attention_scores``. Returns a new tensor (``masked_fill`` is
    out-of-place); the input scores are left unmodified.
    """
    return attention_scores.masked_fill(attention_mask, -10000.0)


# Reduced-precision dtypes exercised by the autocast tests. float16 is always
# included; bfloat16 only when the CUDA device reports support for it. The
# `torch.cuda.is_available()` check comes first so the bf16 capability query
# is only made when a CUDA device is actually present (older torch releases
# query the current device unconditionally inside `is_bf16_supported()`).
autocast_dtypes = (
    (torch.half, torch.bfloat16)
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else (torch.half,)
)


class TestFusedScaleMaskSoftmax(unittest.TestCase):
    """Compare the fused ``FusedScaleMaskSoftmax`` CUDA path with the unfused one.

    Each test builds a pair of modules that differ only in
    ``scaled_masked_softmax_fusion`` and asserts that their forward outputs are
    close, treating the unfused (torch) output as the reference. The backward
    pass is run on both outputs, but the resulting gradients are not compared,
    so backward is only smoke-tested.
    """

    def _setup_fused_softmax(
        self,
        input_in_fp16,
        input_in_bf16,
        scale=None,
        softmax_in_fp32=False,
        attn_mask_type=AttnMaskType.padding,
    ):
        """Build ``(fused_fn, torch_fn)`` with otherwise identical settings.

        ``fused_fn`` is constructed with ``scaled_masked_softmax_fusion=True``
        and ``torch_fn`` with ``scaled_masked_softmax_fusion=False``; both use
        ``attention_mask_func`` as the mask function. Any exception raised by
        the ``FusedScaleMaskSoftmax`` constructor propagates to the caller;
        the tests rely on this to check rejected configurations.
        """
        common_kwargs = dict(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
        )
        fused_fn = FusedScaleMaskSoftmax(
            scaled_masked_softmax_fusion=True, **common_kwargs
        )
        torch_fn = FusedScaleMaskSoftmax(
            scaled_masked_softmax_fusion=False, **common_kwargs
        )
        return fused_fn, torch_fn

    @staticmethod
    def _make_causal_mask(seq_len):
        """Return a ``[1, 1, seq_len, seq_len]`` bool mask on CUDA.

        Entries strictly above the diagonal are True (masked); entries on or
        below the diagonal are False. Built from ``ones`` so that the result
        is deterministic, unlike thresholding random values.
        """
        return (
            torch.ones((seq_len, seq_len), device="cuda")
            .triu(diagonal=1)
            .bool()
            .unsqueeze(0)
            .unsqueeze(0)
        )

    def test_fused_scale_mask_softmax(self):
        """Padding mask: fused kernel vs. unfused torch reference.

        attention_scores.shape = [4, 12, 24, 24]
        mask.shape = [4, 1, 24, 24]
        """
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            (torch.half, torch.bfloat16), (None, 2.0), (False, True),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    # A scale without softmax_in_fp32 is expected to be
                    # rejected when the modules are constructed.
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(
                            input_in_fp16,
                            input_in_bf16,
                            scale,
                            softmax_in_fp32,
                            AttnMaskType.padding,
                        )
                    # `continue`, not `return`: returning would end the whole
                    # test and skip every remaining combination (incl. bf16).
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16,
                    input_in_bf16,
                    scale,
                    softmax_in_fp32,
                    AttnMaskType.padding,
                )

                attention_scores_0 = (
                    torch.randn((4, 12, 24, 24))
                    .to(device="cuda", dtype=dtype)
                    .requires_grad_(True)
                )
                # Independent leaf with identical values for the reference path.
                with torch.no_grad():
                    attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
                mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()

                actual = fused_fn(attention_scores_0, mask)
                expected = torch_fn(attention_scores_1, mask)
                torch.testing.assert_close(actual, expected)

                # Smoke-test backward on both paths; gradients are not compared.
                g0 = torch.rand_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                actual.backward(g0)
                expected.backward(g1)

    def test_autocast_fused_scale_mask_softmax(self):
        """Padding mask: fused kernel under autocast vs. torch reference.

        The fused path receives float32 scores inside
        ``torch.cuda.amp.autocast(dtype=dtype)`` and must produce ``dtype``
        output; the reference receives the same scores pre-cast to ``dtype``
        outside autocast.
        """
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
                )

                attention_scores_0 = (
                    torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
                )
                with torch.no_grad():
                    attention_scores_1 = (
                        attention_scores_0.clone().to(dtype).requires_grad_(True)
                    )
                mask = torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()

                expected = torch_fn(attention_scores_1, mask)
                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(attention_scores_0, mask)
                    self.assertEqual(actual.dtype, dtype)
                torch.testing.assert_close(actual, expected)

                # Smoke-test backward on both paths; gradients are not compared.
                g0 = torch.rand_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                expected.backward(g0)
                actual.backward(g1)

    def test_fused_upper_triangle_mask_softmax(self):
        """Causal mask: fused upper-triangular kernel vs. torch reference.

        attn_weights.shape: [4, 12, 24, 24]
        total_mask.shape: [4, 1, 24, 24]

        total_mask[0, 0] is a 24x24 bool matrix whose entries strictly above
        the diagonal are True (masked) and whose entries on or below the
        diagonal are False.
        """
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            (torch.half, torch.bfloat16), (None, 2.0), (False, True),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    # A scale without softmax_in_fp32 is expected to be
                    # rejected when the modules are constructed.
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(
                            input_in_fp16,
                            input_in_bf16,
                            scale,
                            softmax_in_fp32,
                            AttnMaskType.causal,
                        )
                    # `continue`, not `return`: returning would end the whole
                    # test and skip every remaining combination (incl. bf16).
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16,
                    input_in_bf16,
                    scale,
                    softmax_in_fp32,
                    AttnMaskType.causal,
                )

                attn_weights_0 = (
                    torch.randn((4, 12, 24, 24))
                    .to(device="cuda", dtype=dtype)
                    .requires_grad_(True)
                )
                # Independent leaf with identical values for the reference path.
                with torch.no_grad():
                    attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
                total_mask = self._make_causal_mask(24).repeat((4, 1, 1, 1))

                actual = fused_fn(attn_weights_0, total_mask)
                expected = torch_fn(attn_weights_1, total_mask)
                torch.testing.assert_close(actual, expected)

                # Smoke-test backward on both paths; gradients are not compared.
                g0 = torch.randn_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                actual.backward(g0)
                expected.backward(g1)

    def test_autocast_fused_upper_triangle_mask_softmax(self):
        """Causal mask: fused kernel under autocast vs. torch reference.

        The fused path receives float32 weights inside
        ``torch.cuda.amp.autocast(dtype=dtype)`` and must produce ``dtype``
        output; the reference receives the same weights pre-cast to ``dtype``
        outside autocast. The ``[1, 1, 24, 24]`` mask broadcasts over the
        batch and head dimensions.
        """
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
                )

                attn_weights_0 = (
                    torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
                )
                with torch.no_grad():
                    attn_weights_1 = (
                        attn_weights_0.clone().to(dtype).requires_grad_(True)
                    )
                total_mask = self._make_causal_mask(24)

                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(attn_weights_0, total_mask)
                    self.assertEqual(actual.dtype, dtype)
                expected = torch_fn(attn_weights_1, total_mask)
                torch.testing.assert_close(actual, expected)

                # Smoke-test backward on both paths; gradients are not compared.
                g0 = torch.randn_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                actual.backward(g0)
                expected.backward(g1)