fused_softmax.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.model.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
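        # For y = softmax(s), the standard backward is
        #   ds = y * (dy - sum(dy * y, dim=-1, keepdim=True)),
        # scaled by scale_t to account for the forward scaling; the CUDA
        # extension presumably fuses this, which is why only the softmax
        # output (not the input) needs to be saved for backward.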
        return input_grads, None
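

# Eager-mode reference sketch, added for illustration only (not part of the
# original module and not used by the classes in this file): what
# ScaledUpperTriangMaskedSoftmax.forward computes on a 3D [attn_batches, sq, sk]
# score tensor, assuming the usual causal-mask convention.
def _upper_triang_masked_softmax_reference(inputs, scale):
    scores = inputs * scale
    sq, sk = scores.size(-2), scores.size(-1)
    # strictly upper triangular positions are future tokens and get -inf
    causal_mask = torch.triu(
        torch.ones(sq, sk, device=scores.device), diagonal=1
    ).bool()
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return torch.nn.functional.softmax(scores, dim=-1)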


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(
            inputs, mask, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None
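

# Eager-mode reference sketch, again for illustration only (not part of the
# original module): the equivalent of ScaledMaskedSoftmax.forward, assuming
# `mask` is a boolean tensor broadcastable to `inputs` with True marking the
# positions to be masked out.
def _masked_softmax_reference(inputs, mask, scale):
    scores = inputs * scale
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.nn.functional.softmax(scores, dim=-1)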


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if the input is in fp16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate whether to use the fused softmax CUDA kernel.
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"
 
    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]
        attn_batch_size = data_size[0] * data_size[1]
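        # batch and head dims are folded together; the causal branch below views
        # the scores as [b * np, sq, sk] before calling the fused kernel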

        # constraints on various tensor dimensions to enable warp based
        # optimization and upper triangular optimization (for causal mask)
        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

        # invoke custom kernel
        if self.input_in_fp16 and mask is not None and \
           custom_kernel_constraint and self.scaled_masked_softmax_fusion:
            scale = self.scale if self.scale is not None else 1.0

            if self.attn_mask_type == AttnMaskType.causal:
                assert query_seq_len == key_seq_len, \
                    "causal mask is only for self attention"
                input = input.view(-1, query_seq_len, key_seq_len)
                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                probs = probs.view(*data_size)
            else:
                assert self.attn_mask_type == AttnMaskType.padding
                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            if self.scale is not None:
                input = input * self.scale
            mask_output = self.mask_func(input, mask) if mask is not None else input
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
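

# Usage sketch, illustrative only and not part of the original file. The fused
# CUDA path is deliberately skipped (scaled_masked_softmax_fusion=False and fp32
# inputs), so only the eager fallback in forward() runs; `_example_mask_func` is
# a hypothetical stand-in for Megatron's attention-mask function.
if __name__ == "__main__":
    def _example_mask_func(scores, mask):
        # fill masked positions (True in `mask`) with a large negative value
        return scores.masked_fill(mask, -10000.0)

    b, np_heads, sq, sk = 2, 4, 16, 16
    scores = torch.randn(b, np_heads, sq, sk)
    # padding mask broadcastable to [b, np, sq, sk]; True = masked out
    pad_mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool)

    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,
        mask_func=_example_mask_func,
        softmax_in_fp32=True,
        scale=None,
    )
    probs = fused_softmax(scores, pad_mask)  # [b, np, sq, sk]; rows sum to 1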