# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.model.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
       Fused operation which performs the following three operations in sequence:
       1. Scale the tensor.
       2. Apply upper triangular mask (typically used in GPT models).
       3. Perform softmax.
    """
    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda
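        # Wrap the python float scale in a tensor so it can be stashed with
        # save_for_backward and reused in the backward pass.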
        scale_t = torch.tensor([scale])

        softmax_results =  \
            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda
        softmax_results, scale_t = ctx.saved_tensors

        input_grads =   \
            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
                                                             softmax_results,
                                                             scale_t[0])
        return input_grads, None
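
# Usage sketch for the fused causal kernel above (assumptions: the
# scaled_upper_triang_masked_softmax_cuda extension is built and `scores` is an
# fp16 CUDA tensor of attention scores reshaped to [b * np, sq, sk]):
#
#     probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 1.0)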


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
       Fused operation which performs the following three operations in sequence:
       1. Scale the tensor.
       2. Apply the mask.
       3. Perform softmax.
    """
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda
        scale_t = torch.tensor([scale])

        softmax_results =  \
            scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda
        softmax_results, scale_t = ctx.saved_tensors

        input_grads =   \
            scaled_masked_softmax_cuda.backward(output_grads,
                                                softmax_results,
                                                scale_t[0])
        return input_grads, None, None
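
# Usage sketch for the fused padding-mask kernel above (assumptions: the
# scaled_masked_softmax_cuda extension is built, `scores` is an fp16 CUDA
# tensor of shape [b, np, sq, sk], and `mask` marks the positions to be masked
# out, as produced by Megatron's attention masks):
#
#     probs = ScaledMaskedSoftmax.apply(scores, mask, 1.0)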


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
       fused operation: scaling + mask + softmax
       Arguments:
           input_in_fp16: flag to indicate if input is in fp16 data format.
           attn_mask_type: attention mask type (padding or causal).
           scaled_masked_softmax_fusion: flag to indicate whether to use the
               fused softmax CUDA kernel.
           mask_func: mask function to be applied.
           softmax_in_fp32: if true, softmax is performed at fp32 precision.
           scale: scaling factor used in input tensor scaling.

    """
    def __init__(self, input_in_fp16, attn_mask_type,
                 scaled_masked_softmax_fusion, mask_func,
                 softmax_in_fp32, scale):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, \
            'softmax should be in fp32 when scaled'

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]

        # invoke custom kernel
        if self.input_in_fp16 and key_seq_len <= 2048 and \
           query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:

            scale = self.scale if self.scale is not None else 1.0

            if self.attn_mask_type == AttnMaskType.causal:
                assert query_seq_len == key_seq_len, \
                    "causal mask is only for self attention"
                input = input.view(-1, query_seq_len, key_seq_len)
                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                probs = probs.view(*data_size)
            else:
                assert self.attn_mask_type == AttnMaskType.padding
                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
        else:
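            # Unfused fallback: optionally upcast to fp32, scale, mask, then softmax.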
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            if self.scale is not None:
                input = input * self.scale
            mask_output = self.mask_func(input, mask)
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
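

# Minimal usage sketch exercising the unfused fallback path (illustrative only;
# the fused kernels additionally require the compiled CUDA extensions and fp16
# CUDA inputs). `example_mask_func` is an assumption modelled on Megatron's
# attention_mask_func: True entries in the mask are filled with a large
# negative value before the softmax.
if __name__ == "__main__":

    def example_mask_func(attention_scores, attention_mask):
        return attention_scores.masked_fill(attention_mask, -10000.0)

    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,
        mask_func=example_mask_func,
        softmax_in_fp32=False,
        scale=None)

    # [b, np, sq, sk] attention scores and a padding mask hiding the last keys.
    scores = torch.randn(2, 4, 8, 8)
    mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
    mask[..., 4:] = True

    probs = fused_softmax(scores, mask)
    print(probs.shape)  # torch.Size([2, 4, 8, 8])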