fused_softmax.py 6.74 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

3

4
import torch
5
import torch.nn as nn
6
from megatron.model.enums import AttnMaskType
7

8
9

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """Fused scale + causal (upper-triangular) mask + softmax.

    Performs, in a single CUDA kernel call:
      1. scale the input tensor,
      2. mask the upper triangle (GPT-style causal attention),
      3. softmax along the last dimension.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        # Imported lazily so this module can be loaded even when the
        # CUDA extension has not been built.
        import scaled_upper_triang_masked_softmax_cuda

        scale_tensor = torch.tensor([scale])
        probs = scaled_upper_triang_masked_softmax_cuda.forward(
            inputs, scale_tensor[0]
        )
        ctx.save_for_backward(probs, scale_tensor)
        return probs

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        probs, scale_tensor = ctx.saved_tensors
        grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, probs, scale_tensor[0]
        )
        # One gradient per forward input (inputs, scale); scale gets None.
        return grads, None

40
41

class ScaledMaskedSoftmax(torch.autograd.Function):
    """Fused scale + arbitrary mask + softmax.

    Performs, in a single CUDA kernel call:
      1. scale the input tensor,
      2. apply the caller-supplied attention mask,
      3. softmax along the last dimension.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        # Imported lazily so that importing this module never requires the
        # compiled CUDA extension.
        import scaled_masked_softmax_cuda

        scale_tensor = torch.tensor([scale])
        probs = scaled_masked_softmax_cuda.forward(inputs, mask, scale_tensor[0])
        ctx.save_for_backward(probs, scale_tensor)
        return probs

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        probs, scale_tensor = ctx.saved_tensors
        grads = scaled_masked_softmax_cuda.backward(
            output_grads, probs, scale_tensor[0]
        )
        # Gradients for (inputs, mask, scale); mask and scale get None.
        return grads, None, None

70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        # Lazy import: the CUDA extension is only needed when the op is used.
        import scaled_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )
        # Fix: forward takes two inputs (inputs, scale), so backward returns
        # exactly two gradients. The original returned a spurious third None,
        # apparently copy-pasted from ScaledMaskedSoftmax (whose forward also
        # takes a mask); this matches ScaledUpperTriangMaskedSoftmax.
        return input_grads, None


102
class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        input_in_bf16: flag to indicate if input in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super().__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        # A custom scale is only applied together with fp32 softmax;
        # enforce that combination up front.
        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        """Apply (optional) scaling, masking and softmax to attention scores.

        Args:
            input: score tensor of shape [b, np, sq, sk].
            mask: attention mask, or None for no masking.
        """
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        """Return True when the fused CUDA kernel supports this problem size."""
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 4096  # sk must be in (16, 4096]  (comment fixed: was "16 ~ 2048")
            and sq % 4 == 0  # sq must be divisible by 4
            and sk % 4 == 0  # sk must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            # NOTE: a redundant nested `if 0 <= sk <= 4096:` was removed here;
            # it was always true given `16 < sk <= 4096` above.
            batch_per_block = self.get_batch_per_block(sq, sk, b, np)

            if self.attn_mask_type == AttnMaskType.causal:
                if attn_batches % batch_per_block == 0:
                    return True
            else:
                if sq % batch_per_block == 0:
                    return True
        return False

    def forward_fused_softmax(self, input, mask):
        """Dispatch to the fused CUDA autograd function for this mask type."""
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            if mask is not None:
                return ScaledMaskedSoftmax.apply(input, mask, scale)
            else:
                return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        """Unfused fallback: scale, mask and softmax via plain torch ops."""
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        # Cast back to the original 16-bit dtype when softmax ran in fp32.
        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        """Query the CUDA extension for its batch-per-block launch parameter."""
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)