"""Test for fused softmax functions.

Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
"""  # NOQA
import itertools

import torch
from torch.testing._internal import common_utils

from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax


def attention_mask_func(attention_scores, attention_mask):
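    """Fill masked positions with a large negative value so softmax drives them to ~0."""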
    return attention_scores.masked_fill(attention_mask, -10000.0)


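# Only exercise bfloat16 under autocast on GPUs that support it.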
autocast_dtypes = (
    (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
)


class TestFusedScaleMaskSoftmax(common_utils.TestCase):
    def _setup_fused_softmax(
        self,
        input_in_fp16,
        input_in_bf16,
        scale=None,
        softmax_in_fp32=False,
        attn_mask_type=AttnMaskType.padding,
    ):
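        """Build a fused-kernel and an eager (reference) FusedScaleMaskSoftmax with identical settings."""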
        fused_fn = FusedScaleMaskSoftmax(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
            scaled_masked_softmax_fusion=True,
        )
        torch_fn = FusedScaleMaskSoftmax(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
            scaled_masked_softmax_fusion=False,
        )
        return fused_fn, torch_fn

    def test_fused_scale_mask_softmax(self):
        """
        attention_scores.shape = [4, 12, 24, 24]
        mask.shape = [4, 1, 24, 24]
        """
        for (dtype, scale, softmax_in_fp32, shape) in itertools.product(
            (torch.half, torch.bfloat16),
            (None, 2.0),
            (False, True),
            ((4, 12, 24, 24), (32, 12, 4, 214)),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}-{shape}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(
                            input_in_fp16,
                            input_in_bf16,
                            scale,
                            softmax_in_fp32,
                            AttnMaskType.padding,
                        )
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16,
                    input_in_bf16,
                    scale,
                    softmax_in_fp32,
                    AttnMaskType.padding,
                )

                attention_scores_0 = (
                    torch.randn(shape)
                    .to(device="cuda", dtype=dtype)
                    .requires_grad_(True)
                )
                with torch.no_grad():
                    attention_scores_1 = attention_scores_0.clone().requires_grad_(True)
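                # Boolean padding mask shared across heads (size 1 in dim 1);
                # True marks positions to mask out.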
                mask_shape = (shape[0],) + (1,) + shape[2:]
                mask = torch.randint(0, 2, mask_shape, device="cuda").bool()
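                # Compare the fused-kernel output against the eager reference.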
                expected = fused_fn(attention_scores_0, mask)
                actual = torch_fn(attention_scores_1, mask)
                self.assertEqual(actual, expected)

                g0 = torch.rand_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                expected.backward(g0)
                actual.backward(g1)
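                # Gradients from the fused backward should match the reference.
                self.assertEqual(attention_scores_0.grad, attention_scores_1.grad)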

    def test_autocast_fused_scale_mask_softmax(self):
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding
                )

                attention_scores_0 = (
                    torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
                )
                with torch.no_grad():
                    attention_scores_1 = (
                        attention_scores_0.clone().to(dtype).requires_grad_(True)
                    )
                mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda()

                expected = torch_fn(attention_scores_1, mask)
                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(attention_scores_0, mask)
                    self.assertEqual(actual.dtype, dtype)
                self.assertEqual(actual, expected)

                g0 = torch.rand_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                expected.backward(g0)
                actual.backward(g1)
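                # The fp32 leaf receives an fp32 gradient through autocast's
                # cast, so compare in the autocast dtype.
                self.assertEqual(
                    attention_scores_0.grad.to(dtype), attention_scores_1.grad
                )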

    def test_fused_upper_triangle_mask_softmax(self):
        """
        attn_weights.shape: [4, 12, 24, 24]
        total_mask.shape: [4, 1, 24, 24]

        total_mask[0, 0], a 24x24 matrix is like a lower triangular matrix, but
        upper elements are True and lower elements and diagonal are False.
        """
        for (dtype, scale, softmax_in_fp32) in itertools.product(
            (torch.half, torch.bfloat16), (None, 2.0), (False, True),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(
                            input_in_fp16,
                            input_in_bf16,
                            scale,
                            softmax_in_fp32,
                            AttnMaskType.causal,
                        )
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16,
                    input_in_bf16,
                    scale,
                    softmax_in_fp32,
                    AttnMaskType.causal,
                )

                attn_weights_0 = (
                    torch.randn((4, 12, 24, 24))
                    .to(device="cuda", dtype=dtype)
                    .requires_grad_(True)
                )
                with torch.no_grad():
                    attn_weights_1 = attn_weights_0.clone().requires_grad_(True)
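                # Invert a lower-triangular matrix so that strictly upper-triangular
                # entries are True, i.e. future positions are masked out.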
                total_mask = (
                    ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
                    .unsqueeze(0)
                    .unsqueeze(0)
                )
                total_mask = total_mask.repeat((4, 1, 1, 1))
                expected = fused_fn(attn_weights_0, total_mask)
                actual = torch_fn(attn_weights_1, total_mask)
                self.assertEqual(actual, expected)

                g0 = torch.randn_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                actual.backward(g0)
                expected.backward(g1)
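                # Gradients from the fused backward should match the reference.
                self.assertEqual(attn_weights_0.grad, attn_weights_1.grad)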

    def test_autocast_fused_upper_triangle_mask_softmax(self):
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal
                )

                attn_weights_0 = (
                    torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True)
                )
                with torch.no_grad():
                    attn_weights_1 = (
                        attn_weights_0.clone().to(dtype).requires_grad_(True)
                    )
                total_mask = (
                    ~(torch.tril(torch.randn((24, 24), device="cuda")).bool())
                    .unsqueeze(0)
                    .unsqueeze(0)
                )

                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(attn_weights_0, total_mask)
                    self.assertEqual(actual.dtype, dtype)
                expected = torch_fn(attn_weights_1, total_mask)
                self.assertEqual(actual, expected)

                g0 = torch.randn_like(actual)
                with torch.no_grad():
                    g1 = g0.clone()
                actual.backward(g0)
                expected.backward(g1)
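                # As above, the fp32 leaf's gradient comes back as fp32; compare
                # in the autocast dtype.
                self.assertEqual(attn_weights_0.grad.to(dtype), attn_weights_1.grad)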

if __name__ == "__main__":
    common_utils.run_tests()