# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.model.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
22
23
24
25
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
26
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None
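

# A minimal unfused reference for ScaledUpperTriangMaskedSoftmax above, meant
# only to illustrate the semantics of the fused CUDA kernel: scale, apply an
# upper triangular (causal) mask, then softmax. This helper is not part of the
# original module, and the -10000.0 fill value is an assumed convention; the
# fused kernel's exact masking behaviour may differ.
def _reference_scaled_upper_triang_masked_softmax(inputs, scale):
    # 1. Scale the attention scores.
    scaled = inputs * scale
    # 2. Mask out positions above the diagonal (future tokens) with a large
    #    negative value, mirroring a causal attention mask.
    causal_mask = torch.triu(
        torch.ones_like(scaled, dtype=torch.bool), diagonal=1
    )
    masked = scaled.masked_fill(causal_mask, -10000.0)
    # 3. Softmax over the key dimension.
    return torch.nn.Softmax(dim=-1)(masked)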


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
54
55
56
57
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
58
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(
            inputs, mask, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None
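

# Illustrative (assumed) call pattern for ScaledMaskedSoftmax above: `inputs`
# holds 4D attention scores of shape [b, np, sq, sk] and `mask` is a boolean
# tensor broadcastable to that shape in which True marks positions to mask
# out. Both the shapes and the mask convention are assumptions inferred from
# how FusedScaleMaskSoftmax below calls this kernel.
#
#   probs = ScaledMaskedSoftmax.apply(attention_scores, attention_mask, 1.0)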


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
86
87
88
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
89
        attn_mask_type: attention mask type (pad or causal)
90
91
92
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
93
94

    """

    def __init__(
        self,
        input_in_fp16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]
        assert input.dim() == 4

        # invoke the custom fused CUDA kernel only when its constraints are met
        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
            scale = self.scale if self.scale is not None else 1.0

            if self.attn_mask_type == AttnMaskType.causal:
                assert query_seq_len == key_seq_len, \
                    "causal mask is only for self attention"
                input = input.view(-1, query_seq_len, key_seq_len)
                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                probs = probs.view(*data_size)
            else:
                assert self.attn_mask_type == AttnMaskType.padding
                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            if self.scale is not None:
                input = input * self.scale
            mask_output = self.mask_func(input, mask) if mask is not None else input
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
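

# Illustrative usage sketch (not part of the original module). It exercises
# the non-fused fallback path, so it does not require the custom CUDA
# extensions. The mask_func below (masked_fill with -10000.0, where True means
# "masked out") is an assumed convention chosen for this example.
if __name__ == "__main__":

    def attention_mask_func(attention_scores, attention_mask):
        return attention_scores.masked_fill(attention_mask, -10000.0)

    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,
        mask_func=attention_mask_func,
        softmax_in_fp32=False,
        scale=None,
    )

    # [b, np, sq, sk] attention scores and a [b, 1, sq, sk] boolean padding mask.
    scores = torch.randn(2, 4, 8, 8)
    attention_mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
    probs = fused_softmax(scores, attention_mask)
    print(probs.shape)  # torch.Size([2, 4, 8, 8])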