# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused scaled masked softmax functions"""
import os
from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils
import transformer_engine_extensions as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32


# Thread-count constants read by FusedScaleMaskSoftmax.get_batch_per_block when it
# derives the batches-per-block value used for the fused-kernel availability check.
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128

# Cache of causal masks built by _get_default_causal_mask, keyed by (sq, sk, device).
_default_causal_mask = {}


def _get_default_causal_mask(
    sq: int, sk: int, device: Union[str, torch.device] = "cuda"
) -> torch.Tensor:
    """Return the causal upper triangular mask for softmax input.

    The mask is a boolean tensor of shape ``(sq, sk)`` built with
    ``torch.triu(..., diagonal=sk - sq + 1)``, so the triangle is aligned to the
    bottom-right corner of the matrix. For ``sq == 1`` an all-``False`` mask of shape
    ``(1, sk)`` is returned and nothing is cached.

    Args:
        sq: Number of mask rows.
        sk: Number of mask columns.
        device: Device on which the mask is allocated. Defaults to ``"cuda"``, which
            resolves to the CUDA device that is current at call time.

    Returns:
        The boolean mask. Masks with ``sq != 1`` are cached per ``(sq, sk, device)``
        and the same tensor object is returned on subsequent calls.
    """
    target = torch.device(device)
    if sq == 1:
        return torch.zeros((1, sk), dtype=torch.bool, device=target)

    # Pin an index-less "cuda" to the current device so that a mask cached while one
    # GPU was current is never handed out after the current GPU has changed.
    if target.type == "cuda" and target.index is None:
        target = torch.device("cuda", torch.cuda.current_device())

    cache_key = (sq, sk, target)
    mask = _default_causal_mask.get(cache_key)
    if mask is None:
        diagonal_offset = sk - sq + 1
        mask = torch.triu(
            torch.ones(sq, sk, dtype=torch.bool, device=target),
            diagonal=diagonal_offset)
        _default_causal_mask[cache_key] = mask
    return mask


def _get_onnx_export_causal_mask(
    seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor
) -> torch.Tensor:
    """Slice a causal mask for ONNX export out of a precomputed square mask.

    ONNX export cannot rely on dynamic control flow, and a KV-cache needs non-square
    masks (seq_k covers context plus generated tokens while seq_q is 1).

    `onnx_causal_mask` is expected to be a square triu (k=1) mask. The returned slice
    keeps rows `seq_k - seq_q` up to (not including) `seq_k` and the first `seq_k`
    columns: a square (seq_k, seq_k) triu mask in the context phase and a rectangular
    (1, seq_k) mask in the generation phase.
    """
    assert onnx_causal_mask.dim() == 2
    side = onnx_causal_mask.size(0)
    assert side == onnx_causal_mask.size(1)
    first_row = seq_k - seq_q
    assert side >= first_row >= 0
    return onnx_causal_mask[first_row:seq_k, :seq_k]


def fp32_compute(onnx_symbolic_fn):
    """A decorator that wraps an ONNX symbolic function with FP32 compute operators.

    The returned callable forwards the graph, the input value, the scale and any
    further arguments to ``compute_in_fp32`` together with ``onnx_symbolic_fn``.
    The wrapped function's name, docstring and annotations are preserved.
    """
    import functools  # local import keeps this block independent of the file-level imports

    @functools.wraps(onnx_symbolic_fn)
    def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs):
        return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs)
    return wrapper


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledUpperTriangMaskedSoftmax fwd

        Calls ``tex.scaled_upper_triang_masked_softmax_forward`` with the input and the
        scale, and saves the softmax output together with the scale tensor for backward.
        """
        # The scale is wrapped in a one-element tensor so it can be saved on ctx.
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_upper_triang_masked_softmax_forward(
            inputs, scale_t[0]
        )

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledUpperTriangMaskedSoftmax bwd

        Returns the gradient for ``inputs`` and ``None`` for ``scale``.
        """
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_upper_triang_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledUpperTriangMaskedSoftmax symbolic method"""
        def triangular_mask():
            # Strict upper triangle (k=1) of an all-ones tensor shaped like `inputs`.
            dtype = _type_utils.JitScalarType.INT64
            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
            k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
            mask = g.op("Trilu", ones, k, upper_i=1)
            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
            return mask

        # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward
        # Computed as:
        #   masked_scaled = (1 - mask)*(input*scale)
        #   softmax_mask = mask * -10000
        #   output = softmax(masked_scaled + softmax_mask)
        mask = triangular_mask()
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        inv_mask = g.op("Sub", one, mask)

        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)

        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        out = g.op("Softmax", masked)
        return out

class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply causal mask aligned to the bottom right corner of the input matrix
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledAlignedCausalMaskedSoftmax fwd

        Calls ``tex.scaled_aligned_causal_masked_softmax_forward`` with the input and the
        scale, and saves the softmax output together with the scale tensor for backward.
        """
        # The scale is wrapped in a one-element tensor so it can be saved on ctx.
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(
            inputs, scale_t[0]
        )
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledAlignedCausalMaskedSoftmax bwd

        Returns the gradient for ``inputs`` and ``None`` for ``scale``.
        """
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledAlignedCausalMaskedSoftmax symbolic method"""
        def int64_const(value: int) -> torch._C.Value:
            # Scalar INT64 constant node.
            return g.op("Constant", value_t=torch.tensor(value, dtype=torch.int64))

        def triangular_mask():
            dtype = _type_utils.JitScalarType.INT64
            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)

            # rectangular causal mask aligned to the bottom right corner of Attention matrix
            # The two trailing dimensions are read inside the graph (Shape + Gather) because
            # `inputs` is a graph value, not a tensor with a Python-side size.
            shape = g.op("Shape", inputs)
            rows = g.op("Gather", shape, int64_const(-2), axis_i=0)
            cols = g.op("Gather", shape, int64_const(-1), axis_i=0)
            diag_shift = g.op("Add", g.op("Sub", cols, rows), int64_const(1))

            # Trilu takes the diagonal offset as its second input `k`; the `upper`
            # attribute is only a 0/1 flag selecting the upper or lower triangle.
            mask = g.op("Trilu", ones, diag_shift, upper_i=1)
            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
            return mask

        # Captures the logic of function scaled_aligned_masked_softmax_warp_forward
        # Computed as:
        #   masked_scaled = (1 - mask)*(input*scale)
        #   softmax_mask = mask * -10000
        #   output = softmax(masked_scaled + softmax_mask)
        mask = triangular_mask()
        one = int64_const(1)
        inv_mask = g.op("Sub", one, mask)

        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)

        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        out = g.op("Softmax", masked)
        return out


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(
        ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float
    ) -> torch.Tensor:
        """ScaledMaskedSoftmax fwd

        Calls ``tex.scaled_masked_softmax_forward`` with the input, the mask and the
        scale, and saves the softmax output together with the scale tensor for backward.
        """
        # The scale is wrapped in a one-element tensor so it can be saved on ctx.
        scale_t = torch.tensor([scale])

        softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledMaskedSoftmax bwd

        Returns the gradient for ``inputs`` and ``None`` for both ``mask`` and ``scale``.
        """
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = tex.scaled_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        return input_grads, None, None

    @staticmethod
    @fp32_compute
    def symbolic(
        g: torch.Graph,
        inputs: torch._C.Value,
        mask: torch._C.Value,
        scale: float) -> torch._C.Value:
        """ScaledMaskedSoftmax symbolic method"""
        # Captures the logic of function scaled_masked_softmax_warp_forward.
        # output = softmax(mask(input*scale))
        # Computed as:
        #   masked_scaled = (1 - mask)*(input*scale)
        #   softmax_mask = mask * -10000
        #   output = softmax(masked_scaled + softmax_mask)
        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        inv_mask = g.op("Sub", one, mask)
        # Note: type is hard coded because softmax uses FP16 or BF16
        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        out = g.op("Softmax", masked)
        return out


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledSoftmax fwd

        Calls ``tex.scaled_softmax_forward`` with the input and the scale, and saves the
        softmax output together with the scale tensor for backward.
        """
        # The scale is wrapped in a one-element tensor so it can be saved on ctx.
        scale_t = torch.tensor([scale])

        softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledSoftmax bwd

        Returns the gradient for ``inputs`` and ``None`` for ``scale``.
        """
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = tex.scaled_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        # One entry per forward() input: (inputs, scale).
        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledSoftmax symbolic method"""
        # output = softmax(input*scale)
        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        out = g.op("Softmax", scaled)
        return out



class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
    """

    def __init__(
        self,
        mask_func: Callable,
        softmax_in_fp32: bool = True,
    ) -> None:
        super().__init__()
        # Fusion is on unless NVTE_MASKED_SOFTMAX_FUSION is set to 0.
        self.scaled_masked_softmax_fusion = bool(
            int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
        )
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32

        # Users exporting to ONNX can optimize the attention mask for GPT text generation.
        self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1"))
        if self.kvcache_max_seq > 0:
            # Square triu (k=1) mask that is sliced per call during ONNX export.
            self.register_buffer(
                "onnx_causal_mask",
                torch.triu(
                    torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"),
                    diagonal=1
                ).bool(),
                persistent=False)

    def forward(
        self,
        inp: torch.Tensor,
        mask: torch.Tensor,
        attn_mask_type: str,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """FusedScaleMaskSoftmax fprop

        Records the input dtype flags and `attn_mask_type` on the module, then uses the
        fused kernels when `is_kernel_available` allows it and ONNX export mode is off;
        otherwise falls back to the framework softmax.
        """
        # [b, np, sq, sk]
        assert inp.dim() == 4
        self.input_in_fp16 = inp.dtype == torch.float16
        self.input_in_bf16 = inp.dtype == torch.bfloat16
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type

        assert (
            scale is None or self.softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

        if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode():
            return self.forward_fused_softmax(inp, mask, scale)
        return self.forward_torch_softmax(inp, mask, scale)

    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: # pylint: disable=too-many-return-statements
        """Check FusedScaleMaskSoftmax kernel availability based on size

        Reads `scaled_masked_softmax_fusion`, `input_in_float16` and `attn_mask_type`
        from the module; the latter two are set by `forward`.
        """
        attn_batches = b * np

        if not self.scaled_masked_softmax_fusion:
            return False  # user doesn't want to fuse
        if not self.input_in_float16:
            return False  # input must be fp16 or bf16
        if not 16 < sk < 16384:
            return False  # sk must lie strictly between 16 and 16384
        if sk % 8 != 0:
            return False  # sk must be a multiple of 8
        if self.attn_mask_type == "arbitrary":
            return False  # Custom masks not supported

        if self.attn_mask_type == "causal":
            return True  # no further shape constraints are checked for causal masks

        if (sq % 4 == 0                             # sq must be a multiple of 4
            and attn_batches % 4 == 0               # np * b must be a multiple of 4
        ):
            batch_per_block = self.get_batch_per_block(int(sk))

            if self.attn_mask_type == "padding":
                # Padding masks must be present and match the trailing (sq, sk) dims.
                if (
                    mask is not None
                    and sq % batch_per_block == 0
                    and mask.shape[-2] == sq
                    and mask.shape[-1] == sk
                ):
                    return True
            else:
                if sq % batch_per_block == 0:
                    return True
        return False

    def forward_fused_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Fused masked softmax kernel"""
        scale = 1.0 if scale is None else scale

        if self.attn_mask_type == "causal":
            return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)

        # input is 4D tensor (b, np, sq, sk)
        if mask is not None and self.attn_mask_type != "no_mask":
            return ScaledMaskedSoftmax.apply(inp, mask, scale)
        return ScaledSoftmax.apply(inp, scale)

    def forward_torch_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Framework softmax

        Optionally upcasts fp16/bf16 input to fp32, applies the scale and the mask
        (via `mask_func`), runs softmax over the last dimension and casts the result
        back to the input's half-precision dtype when it was upcast.
        """
        if self.input_in_float16 and self.softmax_in_fp32:
            inp = inp.float()

        if scale is not None:
            inp = inp * scale

        if self.attn_mask_type == "causal":
            # For causal attention the mask argument is replaced by a generated mask.
            seq_len_q, seq_len_k = inp.size(2), inp.size(3)
            if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
                assert self.kvcache_max_seq >= seq_len_k
                mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
            else:
                mask = _get_default_causal_mask(seq_len_q, seq_len_k)

        mask_output = inp
        if mask is not None and self.attn_mask_type != "no_mask":
            mask_output = self.mask_func(inp, mask)
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int:
        """Softmax utility

        Rounds `key_seq_len` up to a power of two and derives the batches-per-block
        value from THREADS_PER_WARP and THREADS_PER_BLOCK.
        """
        pow2 = 1 << (key_seq_len - 1).bit_length()
        warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = THREADS_PER_BLOCK // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block