# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.model.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None
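

# Minimal eager-mode reference sketch, not part of the original module: it is
# assumed, for illustration only, that the fused CUDA kernel above is close to
# this computation (up to numerics and the exact masked fill value); the helper
# name is hypothetical.
def _reference_upper_triang_masked_softmax(inputs, scale):
    # inputs: [attn_batches, sq, sk] causal self-attention scores with sq == sk.
    seq_len = inputs.size(-1)
    # Boolean mask that is True strictly above the diagonal (future positions).
    causal_mask = torch.ones(
        seq_len, seq_len, device=inputs.device
    ).triu(diagonal=1).bool()
    scores = (inputs * scale).masked_fill(causal_mask, float("-inf"))
    return torch.nn.functional.softmax(scores, dim=-1)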


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(
            inputs, mask, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None
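

# Minimal eager-mode reference sketch, not part of the original module: shown
# only to illustrate what the padding-mask kernel is assumed to compute; the
# -10000.0 fill value follows Megatron's attention-mask convention and the
# helper name is hypothetical.
def _reference_scaled_masked_softmax(inputs, mask, scale):
    # inputs: [b, np, sq, sk] attention scores; mask: boolean tensor broadcastable
    # to inputs, where True marks positions to suppress.
    scores = (inputs * scale).masked_fill(mask, -10000.0)
    return torch.nn.functional.softmax(scores, dim=-1)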


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        input_in_bf16: flag to indicate if input in bf16 data format.
        attn_mask_type: attention mask type (padding or causal)
        scaled_masked_softmax_fusion: flag to indicate whether to use the fused softmax kernel.
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.

    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (self.input_in_fp16 and self.input_in_bf16),\
            'both fp16 and bf16 flags cannot be active at the same time.'
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"
    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4
        data_size = input.size()
        query_seq_len = data_size[-2]
        key_seq_len = data_size[-1]
        attn_batch_size = data_size[0] * data_size[1]

        # constraints on various tensor dimensions to enable warp based
        # optimization and upper triangular optimization (for causal mask)
        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

        # invoke custom kernel
        if self.input_in_float16 and mask is not None and \
            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
            scale = self.scale if self.scale is not None else 1.0

            if self.attn_mask_type == AttnMaskType.causal:
                assert query_seq_len == key_seq_len, \
                    "causal mask is only for self attention"
                input = input.view(-1, query_seq_len, key_seq_len)
                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                probs = probs.view(*data_size)
            else:
                assert self.attn_mask_type == AttnMaskType.padding
                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
        else:
            if self.input_in_float16 and self.softmax_in_fp32:
                input = input.float()

            if self.scale is not None:
                input = input * self.scale
            mask_output = self.mask_func(input, mask) if mask is not None else input
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_float16 and self.softmax_in_fp32:
                if self.input_in_fp16:
                    probs = probs.half()
                else:
                    probs = probs.bfloat16()

        return probs
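

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it exercises the
    # non-fused fallback path of FusedScaleMaskSoftmax, so it runs without the
    # custom CUDA extensions. The mask function below is a hypothetical
    # stand-in for the one Megatron normally supplies.
    def example_mask_func(attention_scores, attention_mask):
        return attention_scores.masked_fill(attention_mask, -10000.0)

    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,
        mask_func=example_mask_func,
        softmax_in_fp32=True,
        scale=None,
    )

    # [b, np, sq, sk] attention scores and a boolean mask (True = suppress).
    scores = torch.randn(2, 4, 8, 8)
    attention_mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
    probs = fused_softmax(scores, attention_mask)
    print(probs.shape)  # expected: torch.Size([2, 4, 8, 8])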