fused_softmax.py 7.23 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16

17
import torch
18
import torch.nn as nn
19
from megatron.model.enums import AttnMaskType
20

21
22

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused CUDA operation that, in a single kernel call, does:
    1. Scale the tensor.
    2. Apply an upper-triangular mask (typically used in GPT models).
    3. Perform softmax.

    The kernels live in the ``scaled_upper_triang_masked_softmax_cuda``
    extension, which is imported inside each method rather than at module
    import time.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        """Run the fused forward kernel and save what backward needs.

        Args:
            ctx: autograd context object.
            inputs: tensor passed straight to the CUDA kernel (the caller in
                this file hands it a 3D ``(attn_batches, sq, sk)`` view).
            scale: number multiplied into ``inputs`` by the kernel.

        Returns:
            The softmax output produced by the kernel.
        """
        import scaled_upper_triang_masked_softmax_cuda as kernel

        # save_for_backward only accepts tensors, so the scalar scale is
        # boxed into a one-element tensor and unboxed with [0] when used.
        scale_tensor = torch.tensor([scale])
        probs = kernel.forward(inputs, scale_tensor[0])
        ctx.save_for_backward(probs, scale_tensor)
        return probs

    @staticmethod
    def backward(ctx, output_grads):
        """Return the gradient for ``inputs``; ``scale`` gets ``None``."""
        import scaled_upper_triang_masked_softmax_cuda as kernel

        probs, scale_tensor = ctx.saved_tensors
        grad_inputs = kernel.backward(output_grads, probs, scale_tensor[0])
        # One entry per forward input: (inputs, scale).
        return grad_inputs, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused CUDA operation that, in a single kernel call, does:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.

    The kernels live in the ``scaled_masked_softmax_cuda`` extension, which
    is imported inside each method rather than at module import time.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        """Run the fused forward kernel and save what backward needs.

        Args:
            ctx: autograd context object.
            inputs: tensor passed straight to the CUDA kernel (the caller in
                this file hands it the 4D ``[b, np, sq, sk]`` scores).
            mask: mask tensor forwarded unchanged to the kernel.
            scale: number multiplied into ``inputs`` by the kernel.

        Returns:
            The softmax output produced by the kernel.
        """
        import scaled_masked_softmax_cuda as kernel

        # save_for_backward only accepts tensors, so the scalar scale is
        # boxed into a one-element tensor and unboxed with [0] when used.
        scale_tensor = torch.tensor([scale])
        probs = kernel.forward(inputs, mask, scale_tensor[0])
        ctx.save_for_backward(probs, scale_tensor)
        return probs

    @staticmethod
    def backward(ctx, output_grads):
        """Return the gradient for ``inputs``; ``mask`` and ``scale`` get ``None``."""
        import scaled_masked_softmax_cuda as kernel

        probs, scale_tensor = ctx.saved_tensors
        grad_inputs = kernel.backward(output_grads, probs, scale_tensor[0])
        # One entry per forward input: (inputs, mask, scale).
        return grad_inputs, None, None


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused CUDA operation which performs the following two operations in
    sequence:
    1. Scale the tensor.
    2. Perform softmax.

    No mask is applied. The kernels live in the ``scaled_softmax_cuda``
    extension, which is imported inside each method.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        """Run the fused forward kernel and save what backward needs.

        Args:
            ctx: autograd context object.
            inputs: tensor passed straight to the CUDA kernel.
            scale: number multiplied into ``inputs`` by the kernel.

        Returns:
            The softmax output produced by the kernel.
        """
        import scaled_softmax_cuda

        # save_for_backward only accepts tensors, so the scalar scale is
        # boxed into a one-element tensor and unboxed with [0] when used.
        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        """Return the gradient for ``inputs``; ``scale`` gets ``None``."""
        import scaled_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        # Exactly one gradient per forward input: (inputs, scale). The
        # previous version returned a third, spurious ``None``.
        return input_grads, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax.

    Uses the fused CUDA autograd functions defined in this module when
    ``is_kernel_available`` says the input is supported, and otherwise falls
    back to an implementation built from regular PyTorch ops.

    Arguments:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal).
        scaled_masked_softmax_fusion: flag to indicate the user wants to use
            softmax fusion.
        mask_func: mask function to be applied; called as
            ``mask_func(scores, mask)`` on the PyTorch fallback path.
        softmax_in_fp32: if true, softmax is performed at fp32 precision
            (only consulted by the PyTorch fallback path).
        scale: scaling factor used in input tensor scaling; ``None`` means
            no scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super().__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        # True for either half-precision format.
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        """Scale, mask and softmax ``input``.

        Args:
            input: 4D tensor laid out as ``[b, np, sq, sk]`` (presumably
                batch, heads, query length, key length).
            mask: mask tensor, or ``None`` for no mask.

        Returns:
            Softmax probabilities with the same shape as ``input``.
        """
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        """Return True when the fused CUDA kernels can handle this shape.

        ``mask`` is currently not inspected.
        """
        attn_batches = b * np

        if not (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and 16 < sk <= 4096  # sk must be in (16, 4096]
            and sq % 4 == 0  # sq must be a multiple of 4
            and attn_batches % 4 == 0  # b * np must be a multiple of 4
        ):
            return False

        # The kernel processes rows in groups of batch_per_block; the
        # relevant dimension must divide evenly into those groups.
        batch_per_block = self.get_batch_per_block(sq, sk, b, np)
        if self.attn_mask_type == AttnMaskType.causal:
            return attn_batches % batch_per_block == 0
        return sq % batch_per_block == 0

    def forward_fused_softmax(self, input, mask):
        """Dispatch to the matching fused CUDA autograd function."""
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)

        # input is 4D tensor (b, np, sq, sk)
        if mask is not None:
            return ScaledMaskedSoftmax.apply(input, mask, scale)
        return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        """Fallback path using regular PyTorch ops.

        When the input is half precision and ``softmax_in_fp32`` is set, the
        computation runs in fp32 and the result is cast back to the input's
        half-precision dtype.
        """
        upcast = self.input_in_float16 and self.softmax_in_fp32
        if upcast:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.softmax(mask_output, dim=-1)

        if upcast:
            probs = probs.half() if self.input_in_fp16 else probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        """Ask the ``scaled_masked_softmax_cuda`` extension for its block batching."""
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)