# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.


import torch
import torch.nn as nn
from megatron.legacy.model.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply an upper triangular mask (typically used in GPT models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        try:
            import scaled_upper_triang_masked_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        try:
            import scaled_upper_triang_masked_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None

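# A minimal usage sketch (illustrative shapes; assumes the Apex extension
# `scaled_upper_triang_masked_softmax_cuda` is built and a CUDA device is available).
# The kernel operates on a 3D fp16/bf16 tensor [attn_batches, sq, sk] with sq == sk,
# since the upper-triangular (causal) mask only applies to self-attention:
#
#   scores = torch.randn(8, 1024, 1024, dtype=torch.half, device='cuda')
#   probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 1.0)  # illustrative scale factor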

class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        try:
            import scaled_masked_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        try:
            import scaled_masked_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

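# A minimal usage sketch (illustrative shapes; assumes the Apex extension
# `scaled_masked_softmax_cuda` is built). The mask is a boolean tensor broadcastable
# to the 4D scores [b, np, sq, sk]; in Megatron's convention, True marks positions
# to be masked out:
#
#   scores = torch.randn(2, 16, 128, 128, dtype=torch.half, device='cuda')
#   pad_mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device='cuda')
#   probs = ScaledMaskedSoftmax.apply(scores, pad_mask, 1.0)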

class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following two operations in sequence:
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        try:
            import scaled_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        try:
            import scaled_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

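# ScaledSoftmax covers the no-mask path of FusedScaleMaskSoftmax below; the call
# mirrors the masked variants (illustrative only):
#
#   probs = ScaledSoftmax.apply(scores, 1.0)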

class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax.

    Args:
        input_in_fp16: flag to indicate if the input is in fp16 data format.
        input_in_bf16: flag to indicate if the input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and 16 < sk <= 16384  # sk must be 16 ~ 16384
            and sq % 4 == 0  # sq must be divisible by 4
            and sk % 4 == 0  # sk must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 16384:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            if mask is not None:
                return ScaledMaskedSoftmax.apply(input, mask, scale)
            else:
                return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
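        """Return how many batches the fused masked-softmax CUDA kernel processes
        per block; is_kernel_available uses this to check that the problem size
        divides evenly across kernel launches."""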
        try:
            import scaled_masked_softmax_cuda
        except (ImportError, ModuleNotFoundError):
            print('Please install Apex to use fused_softmax')
            raise

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
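

# A minimal end-to-end sketch of the module above (illustrative values only; the real
# call sites live in Megatron's attention layers). `attention_mask_func` is a
# hypothetical mask function that fills masked positions with a large negative value:
#
#   def attention_mask_func(scores, mask):
#       return scores.masked_fill(mask, -10000.0)
#
#   softmax = FusedScaleMaskSoftmax(
#       input_in_fp16=True,
#       input_in_bf16=False,
#       attn_mask_type=AttnMaskType.causal,
#       scaled_masked_softmax_fusion=True,
#       mask_func=attention_mask_func,
#       softmax_in_fp32=True,
#       scale=None,
#   )
#   probs = softmax(scores, None)  # scores: [b, np, sq, sk]; the causal path needs no mask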