# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused scaled masked softmax functions"""

import os
from typing import Callable

import tensorflow as tf
import transformer_engine_tensorflow as tex

from .module import get_stream_id

THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128


_default_causal_mask = {}

def _get_default_causal_mask(sq: int) -> tf.Tensor:
    """Return the causal upper triangular mask for softmax input"""
    if sq not in _default_causal_mask:
        # In TF, a mask value of 1 (True) means keep and 0 (False) means mask
        # out. In "causal" mode, softmax is computed over the lower triangle.
        mask_operator = tf.linalg.LinearOperatorLowerTriangular(
            tf.ones((sq, sq), dtype=tf.bool))
        mask = mask_operator.to_dense()
        _default_causal_mask[sq] = mask
    return _default_causal_mask[sq]
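# Example (illustrative): for sq = 3 the cached mask keeps the lower
# triangle (True means keep in TF), so position i attends to positions <= i:
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]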

class FusedScaleMaskSoftmax(tf.keras.Model):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        attn_mask_type: attention mask type (padding or causal)
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        attn_mask_type: str,
        mask_func: Callable,
        softmax_in_fp32: bool,
        scale: float,
    ) -> None:
        super().__init__()
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = bool(
            int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
        )
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        self.stream = get_stream_id()

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def __call__(self, inp: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
        """FusedScaleMaskSoftmax fprop"""
        # [b, np, sq, sk]
        assert len(inp.shape) == 4
        self.input_in_fp16 = inp.dtype == tf.float16
        self.input_in_bf16 = inp.dtype == tf.bfloat16
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

        if self.is_kernel_available(*inp.shape):
            return self.forward_fused_softmax(inp, mask)
        return self.forward_tf_softmax(inp, mask)

    def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
        """Check FusedScaleMaskSoftmax kernel availability based on size"""
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and 16 < sk <= 4096  # sk must be in (16, 4096]
            and sq % 4 == 0  # sq must be a multiple of 4
            and attn_batches % 4 == 0  # b * np must be a multiple of 4
        ):
            batch_per_block = self.get_batch_per_block(int(sk))

            if self.attn_mask_type == "causal":
                if attn_batches % batch_per_block == 0:
                    return True
            else:
                if sq % batch_per_block == 0:
                    return True
        return False
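
    # Worked example (illustrative): for sk = 1024, get_batch_per_block
    # rounds sk up to pow2 = 1024, so warp_size = 32, batches_per_warp = 1,
    # warps_per_block = 128 // 32 = 4 and batch_per_block = 4; the causal
    # kernel then requires b * np % 4 == 0 and the padded kernel sq % 4 == 0.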

    @tf.custom_gradient
    def scaled_masked_softmax(self, x: tf.Tensor, mask: tf.Tensor,
                              scale: float):
        """Scaled masked softmax."""
        y = tex.scaled_masked_softmax_forward(x, mask, scale, self.stream)

        def grad_fn(upstream):
            dx = tex.scaled_masked_softmax_backward(upstream, y, scale,
                                                    self.stream)
            return dx, None, None

        return y, grad_fn

    @tf.custom_gradient
    def scaled_softmax(self, x: tf.Tensor, scale: float):
        """Scaled softmax."""
        y = tex.scaled_softmax_forward(x, scale, self.stream)

        def grad_fn(upstream):
            dx = tex.scaled_softmax_backward(upstream, y, scale, self.stream)
            return dx, None

        return y, grad_fn

    @tf.custom_gradient
    def scaled_upper_triang_masked_softmax(self, x: tf.Tensor, scale: float):
        """Scaled upper triangular masked softmax."""
        y = tex.scaled_upper_triang_masked_softmax_forward(x, scale,
                                                           self.stream)

        def grad_fn(upstream):
            dx = tex.scaled_upper_triang_masked_softmax_backward(
                upstream, y, scale, self.stream
            )
            return dx, None

        return y, grad_fn
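
    # All three wrappers above share one pattern: run the fused forward
    # kernel, capture its output y, and reuse y in the backward kernel,
    # since the softmax gradient depends only on the output:
    # dx = scale * y * (dy - sum(dy * y, axis=-1, keepdims=True)).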

    def forward_fused_softmax(
        self,
        inp: tf.Tensor,
        mask: tf.Tensor,
    ) -> tf.Tensor:
        """Fused masked softmax kernel"""
        sq, sk = inp.shape[2], inp.shape[3]
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == "causal":
            assert sq == sk, "causal mask is only for self attention"

            # The kernel expects a 3D tensor (attn_batches, sq, sk).
            original_shape = inp.shape
            inp = tf.reshape(inp, (-1, sq, sk))
            probs = self.scaled_upper_triang_masked_softmax(inp, scale)
            # Restore the original 4D shape (b, np, sq, sk).
            return tf.reshape(probs, original_shape)
        # input is 4D tensor (b, np, sq, sk)
        if mask is not None:
            # The TE kernels use the opposite mask convention from TF:
            # in TE, 1 means mask out and 0 means keep.
            mask = tf.math.logical_not(mask)

            ndims = len(mask.shape)
            assert ndims <= 4, "mask ndims should be <= 4"
            if ndims < 4:
                # Prepend singleton dims so the mask broadcasts
                # against the 4D input.
                broadcast_shape = [1] * (4 - ndims) + mask.shape.as_list()
                mask = tf.reshape(mask, broadcast_shape)
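                # Example (illustrative): a [sq, sk] padding mask becomes
                # [1, 1, sq, sk] and broadcasts over batch and heads.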
            return self.scaled_masked_softmax(inp, mask, scale)
        return self.scaled_softmax(inp, scale)

    def forward_tf_softmax(self, inp: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
        """Framework softmax"""
        if self.input_in_float16 and self.softmax_in_fp32:
            inp = tf.cast(inp, tf.float32)

        if self.scale is not None:
            inp = inp * self.scale

        if self.attn_mask_type == "causal":
            mask = _get_default_causal_mask(inp.shape[2])

        mask_output = self.mask_func(inp, mask) if mask is not None else inp
        probs = tf.nn.softmax(mask_output, axis=-1)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = tf.cast(probs, tf.float16)
            else:
                probs = tf.cast(probs, tf.bfloat16)

        return probs

    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int:
        """Softmax utility"""
        # Round the key sequence length up to the next power of two.
        pow2 = 1 << (key_seq_len - 1).bit_length()
        warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
        batches_per_warp = 2 if pow2 <= 128 else 1
        # Integer division so the result stays an int, as annotated.
        warps_per_block = THREADS_PER_BLOCK // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block
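
# Minimal usage sketch (illustrative; `attention_mask_func`, the head dim of
# 64 and the tensor shapes are assumptions, not part of this module):
#
#     import math
#
#     def attention_mask_func(scores, mask):
#         # mask follows the TF convention: True = keep, False = mask out.
#         return tf.where(mask, scores, -10000.0)
#
#     softmax_op = FusedScaleMaskSoftmax(
#         attn_mask_type="padding",
#         mask_func=attention_mask_func,
#         softmax_in_fp32=True,
#         scale=1.0 / math.sqrt(64),  # 1 / sqrt(head_dim)
#     )
#     # scores: [b, np, sq, sk] in fp16/bf16; mask broadcastable to that shape
#     probs = softmax_op(scores, mask)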