# This code is from NVIDIA Megatron,
#     with minor changes.

import enum

import torch
import torch.nn as nn

from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader

try:
    from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
except ImportError:
    scaled_masked_softmax = None
    scaled_upper_triang_masked_softmax = None


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:

        1.  Scale the tensor.
        2.  Apply upper triangular mask (typically used in gpt models).
        3.  Perform softmax.
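
    A minimal usage sketch (illustrative only): it assumes the fused CUDA kernel
    has been built and that the input is a contiguous fp16 CUDA tensor of shape
    (attn_batches, sq, sk) with sq == sk; the shapes below are made up:

        attn_scores = torch.randn(4 * 16, 128, 128, dtype=torch.float16, device="cuda")
        probs = ScaledUpperTriangMaskedSoftmax.apply(attn_scores, 0.125)  # scale = 1/8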
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        # build and load kernel if not pre-built
        global scaled_upper_triang_masked_softmax
        if scaled_upper_triang_masked_softmax is None:
            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:

        1.  Scale the tensor.
        2.  Apply the mask.
        3.  Perform softmax.
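
    A minimal usage sketch (illustrative only): it assumes the fused CUDA kernel
    has been built; the mask shape and the "True = masked out" convention below
    are assumptions made for the example, not guaranteed by this file:

        scores = torch.randn(2, 16, 128, 128, dtype=torch.float16, device="cuda")
        mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
        probs = ScaledMaskedSoftmax.apply(scores, mask, 0.125)  # scale = 1/8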
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])

        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: Flag to indicate whether the input is in fp16 format.
        input_in_bf16: Flag to indicate whether the input is in bf16 format.
        attn_mask_type: Attention mask type (padding or causal).
        scaled_masked_softmax_fusion: Flag to indicate whether to use the fused softmax kernel.
        mask_func: Mask function to be applied to the input.
        softmax_in_fp32: If True, softmax is performed at fp32 precision.
        scale: Scaling factor used in input tensor scaling.
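
    A minimal usage sketch (illustrative only; the lambda below mimics a
    Megatron-style mask function and is an assumption, not part of this module):

        softmax = FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
            mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
            softmax_in_fp32=True,
            scale=None,
        )
        # attention scores of shape [b, np, sq, sk] and a broadcastable boolean mask
        probs = softmax(attention_scores, attention_mask)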
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be in fp16/bf16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be within (16, 2048]
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type.value > 1:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type.value > 1:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    def get_batch_per_block(self, sq, sk, b, np):
        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
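

# A minimal smoke test sketch (illustrative only): it forces the torch fallback
# path, so it does not require the fused CUDA kernels to be built; the mask
# function and shapes below are assumptions made for the example.
if __name__ == "__main__":
    fused_softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,  # disable fusion -> torch fallback
        mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
        softmax_in_fp32=True,
        scale=None,
    )
    scores = torch.randn(2, 4, 8, 8)  # [b, np, sq, sk]
    attention_mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)  # True = masked out
    probs = fused_softmax(scores, attention_mask)
    print(probs.shape)  # expected: torch.Size([2, 4, 8, 8])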