# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply an upper triangular mask (typically used in GPT models).
    3. Perform softmax.
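
    Note: as invoked by FusedScaleMaskSoftmax below, inputs is expected to be
    a 3D fp16 tensor of shape [batches * heads, sq, sk] with sq == sk and
    sk <= 2048.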
    """

    @staticmethod
    def forward(ctx, inputs, scale):
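        # Import the compiled extension lazily, at first use, rather than at
        # module import time.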
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
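
    Note: as invoked by FusedScaleMaskSoftmax below, inputs is the 4D
    [b, np, sq, sk] attention-score tensor in fp16 with sq == sk and
    sk <= 2048.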
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(
            inputs, mask, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate whether the input is in fp16 format.
        upper_triang_mask_fusion: if true, use the fused kernel that applies an
                                  upper triangular (causal) mask, as used in
                                  GPT-family networks.
        general_mask_fusion: if true, use the fused kernel that applies a
                             general mask.
        mask_func: mask function applied when the fused kernels are not used.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor applied to the input tensor before softmax.

    """

    def __init__(
        self,
        input_in_fp16,
        upper_triang_mask_fusion,
        general_mask_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.upper_triang_mask_fusion = upper_triang_mask_fusion
        self.general_mask_fusion = general_mask_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, s, s]
        data_size = input.size()
        assert input.dim() == 4

        # invoke custom kernel
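        # The fused kernels are only taken for fp16 inputs with a square
        # [sq, sk] attention matrix whose sequence length is at most 2048;
        # otherwise we fall back to the unfused PyTorch path below.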
        if (
            self.input_in_fp16
            and data_size[-1] <= 2048
            and (self.upper_triang_mask_fusion or self.general_mask_fusion)
            and input.size()[2] == input.size()[3]
        ):
            scale = self.scale if self.scale is not None else 1.0
            if self.upper_triang_mask_fusion:
                input = input.view(-1, data_size[2], data_size[3])
                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                probs = probs.view(*data_size)
            else:
                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            if self.scale is not None:
                input = input * self.scale
            mask_output = self.mask_func(input, mask) if mask is not None else input
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
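

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# exercises the unfused fallback path of FusedScaleMaskSoftmax on CPU, so the
# compiled CUDA extensions are not required; the mask function and tensor
# shapes below are assumptions chosen only for demonstration.
if __name__ == "__main__":

    def demo_mask_func(scores, mask):
        # Fill masked positions with a large negative value before softmax.
        return scores.masked_fill(mask, -10000.0)

    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,  # fp32 input, so the fused fp16 kernel path is skipped
        upper_triang_mask_fusion=False,
        general_mask_fusion=False,
        mask_func=demo_mask_func,
        softmax_in_fp32=False,
        scale=None,
    )

    # [b, np, sq, sk] attention scores and a causal (upper triangular) mask.
    scores = torch.randn(2, 4, 8, 8)
    causal_mask = torch.triu(torch.ones(8, 8), diagonal=1).bool()
    causal_mask = causal_mask.view(1, 1, 8, 8).expand(2, 4, 8, 8)

    probs = fused_softmax(scores, causal_mask)
    print(probs.shape)  # torch.Size([2, 4, 8, 8]); rows along the last dim sum to 1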