"""Tests for the fused softmax functions.

Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
"""  # NOQA
import itertools
import unittest

import torch

from apex.transformer import AttnMaskType
from apex.transformer.functional import FusedScaleMaskSoftmax


def attention_mask_func(attention_scores, attention_mask):
    """Mask attention scores by overwriting masked positions.

    Args:
        attention_scores: tensor of raw attention scores.
        attention_mask: boolean tensor (broadcastable to ``attention_scores``);
            ``True`` marks a position to be masked out.

    Returns:
        A new tensor equal to ``attention_scores`` except that every position
        where ``attention_mask`` is ``True`` holds ``-10000.0``. The input is
        not modified (out-of-place ``masked_fill``).
    """
    # A large negative (not -inf) so softmax drives these entries to ~0
    # without producing NaNs for fully-masked rows.
    masked_value = -10000.0
    return attention_scores.masked_fill(attention_mask, masked_value)


# Reduced-precision dtypes exercised by the tests: fp16 always, bf16 only when
# the current CUDA device supports it. ``is_bf16_supported()`` may try to
# initialise CUDA, so short-circuit on ``is_available()`` first to keep this
# module importable (and its tests skippable) on hosts without CUDA.
autocast_dtypes = (
    (torch.half, torch.bfloat16)
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else (torch.half,)
)


@unittest.skipUnless(torch.cuda.is_available(), "FusedScaleMaskSoftmax tests require a CUDA device")
class TestFusedScaleMaskSoftmax(unittest.TestCase):
    """Compare the fused softmax kernels against the unfused PyTorch path.

    Every test builds two ``FusedScaleMaskSoftmax`` modules that differ only in
    ``scaled_masked_softmax_fusion`` and checks that their forward outputs agree.
    The backward pass is run on both outputs as a smoke test only; gradients
    are not compared.
    """

    # Shape of the attention scores used by every test.
    _SCORES_SHAPE = (4, 12, 24, 24)

    def _setup_fused_softmax(self, input_in_fp16, input_in_bf16, scale=None, softmax_in_fp32=False, attn_mask_type=AttnMaskType.padding):
        """Build a fused module and an unfused reference with identical settings.

        Returns:
            ``(fused_fn, torch_fn)``: two ``FusedScaleMaskSoftmax`` instances
            that differ only in ``scaled_masked_softmax_fusion`` (True / False).

        Raises:
            RuntimeError: raised by ``FusedScaleMaskSoftmax`` for argument
                combinations it rejects; the tests below expect it when
                ``scale`` is given without ``softmax_in_fp32``.
        """
        common_kwargs = dict(
            input_in_fp16=input_in_fp16,
            input_in_bf16=input_in_bf16,
            mask_func=attention_mask_func,
            scale=scale,
            softmax_in_fp32=softmax_in_fp32,
            attn_mask_type=attn_mask_type,
        )
        fused_fn = FusedScaleMaskSoftmax(scaled_masked_softmax_fusion=True, **common_kwargs)
        torch_fn = FusedScaleMaskSoftmax(scaled_masked_softmax_fusion=False, **common_kwargs)
        return fused_fn, torch_fn

    @staticmethod
    def _random_padding_mask():
        """Return a random boolean mask of shape [4, 1, 24, 24] (True = masked out)."""
        return torch.randint(0, 2, (4, 1, 24, 24), device="cuda").bool()

    @staticmethod
    def _causal_mask(batch_size=4):
        """Return a [batch_size, 1, 24, 24] mask that is True strictly above the diagonal."""
        upper = torch.triu(torch.ones((24, 24), dtype=torch.bool, device="cuda"), diagonal=1)
        return upper.unsqueeze(0).unsqueeze(0).repeat((batch_size, 1, 1, 1))

    def _check_against_torch(self, attn_mask_type, make_mask):
        """Compare fused vs. unfused softmax for every dtype/scale/fp32 combination.

        Combinations with a ``scale`` but without ``softmax_in_fp32`` must make
        the module constructor raise ``RuntimeError``; those are checked for
        that and then skipped.

        Args:
            attn_mask_type: ``AttnMaskType`` passed to both modules.
            make_mask: zero-argument callable returning the boolean mask.
        """
        for dtype, scale, softmax_in_fp32 in itertools.product(
                autocast_dtypes,  # bf16 only where the device supports it
                (None, 2.0),
                (False, True),
        ):
            with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                if not (scale is None or softmax_in_fp32):
                    with self.assertRaises(RuntimeError):
                        self._setup_fused_softmax(
                            input_in_fp16, input_in_bf16, scale, softmax_in_fp32, attn_mask_type)
                    # ``continue`` (not ``return``) so the remaining
                    # combinations are still exercised.
                    continue
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, scale, softmax_in_fp32, attn_mask_type)

                scores_fused = torch.randn(self._SCORES_SHAPE).to(device="cuda", dtype=dtype).requires_grad_(True)
                with torch.no_grad():
                    # Independent leaf with identical values for the reference path.
                    scores_ref = scores_fused.clone().requires_grad_(True)
                mask = make_mask()

                actual = fused_fn(scores_fused, mask)
                expected = torch_fn(scores_ref, mask)
                torch.testing.assert_close(actual, expected)

                # Smoke-test backward on both paths (gradients are not compared).
                grad_fused = torch.randn_like(actual)
                with torch.no_grad():
                    grad_ref = grad_fused.clone()
                actual.backward(grad_fused)
                expected.backward(grad_ref)

    def _check_autocast_against_torch(self, attn_mask_type, make_mask):
        """Check the fused module under ``torch.cuda.amp.autocast`` against the reference.

        The fused module receives fp32 scores inside an autocast region and must
        produce an output of the autocast dtype; the reference receives the same
        scores already cast to that dtype, outside autocast.

        Args:
            attn_mask_type: ``AttnMaskType`` passed to both modules.
            make_mask: zero-argument callable returning the boolean mask.
        """
        for dtype in autocast_dtypes:
            with self.subTest(f"{dtype}"):
                input_in_fp16 = dtype == torch.half
                input_in_bf16 = dtype == torch.bfloat16
                fused_fn, torch_fn = self._setup_fused_softmax(
                    input_in_fp16, input_in_bf16, attn_mask_type=attn_mask_type)

                scores_fp32 = torch.randn(self._SCORES_SHAPE, device="cuda").requires_grad_(True)
                with torch.no_grad():
                    scores_low = scores_fp32.clone().to(dtype).requires_grad_(True)
                mask = make_mask()

                expected = torch_fn(scores_low, mask)
                with torch.cuda.amp.autocast(dtype=dtype):
                    actual = fused_fn(scores_fp32, mask)
                    self.assertEqual(actual.dtype, dtype)
                torch.testing.assert_close(actual, expected)

                # Smoke-test backward on both paths (gradients are not compared).
                grad_fused = torch.randn_like(actual)
                with torch.no_grad():
                    grad_ref = grad_fused.clone()
                actual.backward(grad_fused)
                expected.backward(grad_ref)

    def test_fused_scale_mask_softmax(self):
        """
        attention_scores.shape = [4, 12, 24, 24]
        mask.shape = [4, 1, 24, 24]
        """
        self._check_against_torch(AttnMaskType.padding, self._random_padding_mask)

    def test_autocast_fused_scale_mask_softmax(self):
        """Padding-mask softmax under autocast; mask.shape = [4, 1, 24, 24]."""
        self._check_autocast_against_torch(AttnMaskType.padding, self._random_padding_mask)

    def test_fused_upper_triangle_mask_softmax(self):
        """
        attn_weights.shape: [4, 12, 24, 24]
        total_mask.shape: [4, 1, 24, 24]

        total_mask[0, 0] is a 24x24 boolean matrix that is True strictly above
        the diagonal and False on and below it.
        """
        self._check_against_torch(AttnMaskType.causal, self._causal_mask)

    def test_autocast_fused_upper_triangle_mask_softmax(self):
        """Causal-mask softmax under autocast; total_mask.shape = [1, 1, 24, 24]."""
        self._check_autocast_against_torch(
            AttnMaskType.causal, lambda: self._causal_mask(batch_size=1))