# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused scaled masked softmax functions"""
import os
from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils
import transformer_engine_extensions as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32


THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128

_default_causal_mask = {}

def _get_default_causal_mask(sq: int) -> torch.Tensor:
    """Return the causal upper triangular mask for softmax input"""
    if sq not in _default_causal_mask:
        _default_causal_mask[sq] = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
    return _default_causal_mask[sq]
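
# Illustrative example (added note, not part of the upstream API): True marks the
# upper-triangular positions that the mask function will later fill with a large
# negative value, so softmax assigns them ~zero probability. For sq=3:
#   _get_default_causal_mask(3)
#   -> tensor([[False,  True,  True],
#              [False, False,  True],
#              [False, False, False]], device='cuda:0')
# Repeated calls with the same sq reuse the cached tensor instead of rebuilding it.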


def _get_onnx_export_causal_mask(
    seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor
) -> torch.Tensor:
    """Return the causal upper triangular mask for softmax input, for ONNX export.

    ONNX does not support dynamic control flow and requires non-square masks when
    using a KV-cache (seq_k's length is len(context)+len(generative) while seq_q's length is 1).

    Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the correct
    shape for GPT context and generation phases.
    In the context phase the derived mask is a square triu of shape (seq_k, seq_k), and in
    the generation phase the mask is rectangular with shape (1, seq_k).
    """
    assert len(onnx_causal_mask.size()) == 2
    assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1)
    assert onnx_causal_mask.size(0) >= (seq_k-seq_q) >= 0
    derived_mask = onnx_causal_mask[seq_k-seq_q:seq_k, :seq_k]
    return derived_mask
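
# Illustrative sketch of the slicing above (assuming a 4x4 triu(k=1) source mask):
#   context phase    (seq_q=4, seq_k=4): onnx_causal_mask[0:4, :4] -> full 4x4 causal mask
#   generation phase (seq_q=1, seq_k=4): onnx_causal_mask[3:4, :4] -> shape (1, 4); this
#     last row is all False, so the single new query attends to every cached key.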


def fp32_compute(onnx_symbolic_fn):
    """A decorator that wraps an ONNX symoblic function with FP32 compute operators."""
    def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs):
        return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs)
    return wrapper
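
# Usage sketch (illustrative): `compute_in_fp32` (imported above) is expected to
# cast the input to FP32 before invoking `onnx_symbolic_fn` and cast the result
# back to the input's original dtype, so the exported graph evaluates the softmax
# at full precision:
#
#   @staticmethod
#   @fp32_compute
#   def symbolic(g, inputs, scale): ...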


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in GPT models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledUpperTriangMaskedSoftmax fwd"""
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_upper_triang_masked_softmax_forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledUpperTriangMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_upper_triang_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledUpperTriangMaskedSoftmax symbolic method"""
        def triangular_mask():
            dtype = _type_utils.JitScalarType.INT64
            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
            k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
            mask = g.op("Trilu", ones, k, upper_i=1)
            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
            return mask

        # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
        mask = triangular_mask()
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        inv_mask = g.op("Sub", one, mask)

        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)

        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        out = g.op("Softmax", masked)
        return out
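
# Usage sketch (illustrative; assumes a CUDA build of transformer_engine_extensions):
#   scores = torch.randn(8 * 16, 128, 128, dtype=torch.float16, device="cuda")
#   probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 1.0 / 8.0)
# The input is 3D (attn_batches, sq, sk) with sq == sk; row i of each output matrix
# attends only to keys j <= i, and each row sums to 1.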


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(
        ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float
    ) -> torch.Tensor:
        """ScaledMaskedSoftmax fwd"""
        scale_t = torch.tensor([scale])

        softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = tex.scaled_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

    @staticmethod
    @fp32_compute
    def symbolic(
        g: torch.Graph,
        inputs: torch._C.Value,
        mask: torch._C.Value,
        scale: float) -> torch._C.Value:
        """ScaledMaskedSoftmax symbolic method"""
        # Captures the logic of function scaled_masked_softmax_warp_forward.
        # output = softmax(mask(input*scale))
        # Computed as:
        #   masked_scaled = (1 - mask)*(input*scale)
        #   softmax_mask = mask * -10000
        #   output = softmax(masked_scaled + softmax_mask)
        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        inv_mask = g.op("Sub", one, mask)
        # Note: type is hard coded because softmax uses FP16 or BF16
        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        out = g.op("Softmax", masked)
        return out
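
# Usage sketch (illustrative): `mask` marks positions to drop with 1/True and is
# typically broadcast over heads, e.g. shape (b, 1, sq, sk):
#   x = torch.randn(2, 16, 64, 64, dtype=torch.float16, device="cuda")
#   pad_mask = torch.zeros(2, 1, 64, 64, dtype=torch.bool, device="cuda")
#   probs = ScaledMaskedSoftmax.apply(x, pad_mask, 0.125)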


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledSoftmax fwd"""
        scale_t = torch.tensor([scale])

        softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = tex.scaled_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledSoftmax symbolic method"""
        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        out = g.op("Softmax", scaled)
        return out
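
# Reference semantics (illustrative): apart from using the fused kernel, the result
# should match plain PyTorch:
#   ScaledSoftmax.apply(x, s) ~= torch.softmax(x.float() * s, dim=-1).to(x.dtype)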



class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        attn_mask_type: attention mask type (pad or causal)
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
    """

    def __init__(
        self,
        attn_mask_type: str,
        mask_func: Callable,
        softmax_in_fp32: bool = True,
    ) -> None:
        super().__init__()
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = bool(
            int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
        )
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32

        # Users exporting to ONNX can optimize the attention mask for GPT text generation.
        self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1"))
        if self.kvcache_max_seq > 0:
            self.register_buffer(
                "onnx_causal_mask",
                torch.triu(
                    torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"),
                    diagonal=1
                ).bool(),
                persistent=False)

    def forward(
        self,
        inp: torch.Tensor,
        mask: torch.Tensor,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """FusedScaleMaskSoftmax fprop"""
        # [b, np, sq, sk]
        assert inp.dim() == 4
        self.input_in_fp16 = inp.dtype == torch.float16
        self.input_in_bf16 = inp.dtype == torch.bfloat16
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

        assert (
            scale is None or self.softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

        if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
            return self.forward_fused_softmax(inp, mask, scale)
        return self.forward_torch_softmax(inp, mask, scale)

    def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
        """Check FusedScaleMaskSoftmax kernel availability based on size"""
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and 16 < sk <= 4096  # sk must be within (16, 4096]
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # b * np must be divisible by 4
        ):
            if 0 <= sk <= 4096:
                batch_per_block = self.get_batch_per_block(int(sk))

                if self.attn_mask_type == "causal":
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False
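
    # Worked example (illustrative): for b=8, np=16, sq=2048, sk=2048 in FP16,
    # attn_batches = 8 * 16 = 128 and get_batch_per_block(2048) = 4, so the causal
    # path (128 % 4 == 0) and the padded path (sq: 2048 % 4 == 0) both report the
    # fused kernel as available.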

    def forward_fused_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Fused masked softmax kernel"""
        b, np, sq, sk = inp.size()
        scale = 1.0 if scale is None else scale

        if self.attn_mask_type == "causal":
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            inp = inp.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
            return probs.view(b, np, sq, sk)
        # input is 4D tensor (b, np, sq, sk)
        if mask is not None:
            return ScaledMaskedSoftmax.apply(inp, mask, scale)
        return ScaledSoftmax.apply(inp, scale)

    def forward_torch_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Framework softmax"""
        if self.input_in_float16 and self.softmax_in_fp32:
            inp = inp.float()

        if scale is not None:
            inp = inp * scale

        if self.attn_mask_type == "causal":
            if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
                seq_len_q, seq_len_k = inp.size(2), inp.size(3)
                assert self.kvcache_max_seq >= seq_len_k
                mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
            else:
                mask = _get_default_causal_mask(inp.size(2))

        mask_output = self.mask_func(inp, mask) if mask is not None else inp
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int:
        """Softmax utility"""
        pow2 = 1 << (key_seq_len - 1).bit_length()
        warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = THREADS_PER_BLOCK // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block
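
# Worked example (illustrative) of get_batch_per_block:
#   key_seq_len=512: pow2=512, warp_size=32, batches_per_warp=1,
#                    warps_per_block=128//32=4 -> batches_per_block=4
#   key_seq_len=100: pow2=128, warp_size=32, batches_per_warp=2,
#                    warps_per_block=4 -> batches_per_block=8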