# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused scaled masked softmax functions"""
import os
from typing import Callable, Tuple, Union, Optional
import torch
from torch import nn
import transformer_engine_torch as tex


THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128


_default_causal_mask = {}


def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
    """Return the causal upper triangular mask for softmax input"""
    matrix_identifiers = (mask_type, sq, sk)
    if matrix_identifiers not in _default_causal_mask:
        diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
        _default_causal_mask[matrix_identifiers] = torch.triu(
            torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=diagonal_offset
        )
    return _default_causal_mask[matrix_identifiers]
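
# A small illustration (comments only, not executed; the mask is allocated with
# device="cuda", so a CUDA device is assumed). True marks positions that the
# softmax will mask out:
#
#   _get_default_causal_mask("causal", 3, 3)                # diagonal offset 1
#       [[0, 1, 1],
#        [0, 0, 1],
#        [0, 0, 0]]
#
#   _get_default_causal_mask("causal_bottom_right", 2, 4)   # offset sk - sq + 1 = 3
#       [[0, 0, 0, 1],
#        [0, 0, 0, 0]]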


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the upper triangular mask (typically used in GPT models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledUpperTriangMaskedSoftmax fwd"""
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledUpperTriangMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_upper_triang_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None
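

# A minimal usage sketch for ScaledUpperTriangMaskedSoftmax (comments only, not
# executed). The fused kernel expects a half-precision CUDA tensor; the
# [attn_batches, sq, sq] layout and the sizes below are assumptions made for
# illustration, not requirements documented here.
#
#   x = torch.randn(8, 128, 128, dtype=torch.float16, device="cuda", requires_grad=True)
#   y = ScaledUpperTriangMaskedSoftmax.apply(x, 0.125)  # scale, mask upper triangle, softmax
#   y.sum().backward()                                  # gradient flows through the fused backward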


class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the causal mask aligned to the bottom-right corner of the input matrix.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledAlignedCausalMaskedSoftmax fwd"""
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledAlignedCausalMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledMaskedSoftmax fwd"""
        scale_t = torch.tensor([scale])

        softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None
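

# A minimal usage sketch for ScaledMaskedSoftmax (comments only, not executed),
# assuming an fp16 CUDA tensor `scores` of shape [b, np, sq, sk] and a boolean
# `mask` of shape [b, 1, sq, sk] (or [1, 1, sq, sk]) in which True marks the
# positions to be masked out:
#
#   probs = ScaledMaskedSoftmax.apply(scores, mask, 1.0)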


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following two operations in sequence:
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledSoftmax fwd"""
        scale_t = torch.tensor([scale])

        softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        mask_func: mask function to be applied.
        softmax_in_fp32: if True, softmax is performed in fp32 precision.
    """

    def __init__(
        self,
        mask_func: Callable,
        softmax_in_fp32: bool = True,
    ) -> None:
        super().__init__()
        self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")))
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32

    def forward(
        self,
        inp: torch.Tensor,
        mask: torch.Tensor,
        attn_mask_type: str,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """FusedScaleMaskSoftmax fprop"""
        # [b, np, sq, sk]
        assert inp.dim() == 4
        self.input_in_fp16 = inp.dtype == torch.float16
        self.input_in_bf16 = inp.dtype == torch.bfloat16
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type

        assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"

        if self.is_kernel_available(mask, *inp.size()):
            return self.forward_fused_softmax(inp, mask, scale)
        return self.forward_torch_softmax(inp, mask, scale)

    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
        """Check FusedScaleMaskSoftmax kernel availability based on size"""
        attn_batches = b * np
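        # Illustrative example: with fp16 input of shape [b=2, np=16, sq=512, sk=512],
        # a padding mask of shape [2, 1, 512, 512], and NVTE_MASKED_SOFTMAX_FUSION
        # enabled, every check below passes (get_batch_per_block(512) == 4 and
        # 512 % 4 == 0), so the fused kernel is selected.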

        if not self.scaled_masked_softmax_fusion:
            return False  # user doesn't want to fuse
        if not self.input_in_float16:
            return False  # input must be fp16 or bf16
        if not 16 < sk < 16384:
            return False  # sk must be in the range (16, 16384)
        if sk % 8 != 0:
            return False  # sk must be divisible by 8
        if sq == 1:
            return False  # sq must be > 1
        if self.attn_mask_type == "causal" and sq != sk:
            return False  # fused causal kernel only supports causal_bottom_right alignment

        if (
            sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            batch_per_block = self.get_batch_per_block(int(sk))
            if "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary":
                if (
                    mask is not None
                    and sq % batch_per_block == 0
                    and mask.shape[0] in [1, b]
                    and mask.shape[1:] == (1, sq, sk)
                ):
                    return True
            else:
                if sq % batch_per_block == 0:
                    return True
        return False

    def forward_fused_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """
        Fused masked softmax path.
          attn_mask_type                                       | module
        -----------------------------------------------------------------------------------------
          no_mask                                              | ScaledSoftmax
          causal (self-attention), causal_bottom_right         | ScaledAlignedCausalMaskedSoftmax
          padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax
          arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk])         | ScaledMaskedSoftmax
        """
        scale = 1.0 if scale is None else scale

        # Disabled for now until the misalignment bug is fixed.
        # if self.attn_mask_type in ["causal", "causal_bottom_right"]:
        #    return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)

        # mask is a 4D tensor of shape (1, 1, sq, sk) or (b, 1, sq, sk)
        if mask is not None and self.attn_mask_type != "no_mask":
            return ScaledMaskedSoftmax.apply(inp, mask, scale)
        return ScaledSoftmax.apply(inp, scale)

    def forward_torch_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Framework softmax"""
        if self.input_in_float16 and self.softmax_in_fp32:
            inp = inp.float()

        if scale is not None:
            inp = inp * scale

        if self.attn_mask_type in ["causal", "causal_bottom_right"]:
            seq_len_q, seq_len_k = inp.size(2), inp.size(3)
            causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k)
            if mask is None:
                mask = causal_mask
            else:
                mask = torch.logical_or(mask, causal_mask)

        mask_output = inp
        if mask is not None and self.attn_mask_type != "no_mask":
            mask_output = self.mask_func(inp, mask)
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int:
        """Softmax utility"""
        pow2 = 1 << (key_seq_len - 1).bit_length()
        warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = THREADS_PER_BLOCK // warp_size
        batches_per_block = warps_per_block * batches_per_warp
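        # Worked examples: key_seq_len=1024 -> pow2=1024, warp_size=32,
        # batches_per_warp=1, warps_per_block=128 // 32 = 4, so 4 batches per block;
        # key_seq_len=128 -> pow2=128, batches_per_warp=2, so 8 batches per block.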
        return batches_per_block
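

# A minimal end-to-end sketch (comments only, not executed). `attention_mask_func`
# is a hypothetical example, not part of this module; FusedScaleMaskSoftmax calls
# mask_func(scores, mask) on the torch fallback path, so it should fill positions
# where mask is True with a large negative value.
#
#   def attention_mask_func(scores, mask):
#       return scores.masked_fill(mask, -10000.0)
#
#   softmax = FusedScaleMaskSoftmax(mask_func=attention_mask_func)
#   # scores: fp16/bf16 CUDA tensor of shape [b, np, sq, sk]; mask: bool [b, 1, sq, sk] or None
#   probs = softmax(scores, mask, attn_mask_type="padding", scale=0.125)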