"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs


class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension of the input is split in half; FATReLU is applied to
    the first half and the result is multiplied elementwise with the second
    half, i.e. x -> FATReLU(x[:d]) * x[d:] with d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Cut-off handed to F.threshold: gate values that are not strictly
        # greater than this are replaced by 0.
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation of the gated FATReLU."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path; simply reuses the PyTorch-native implementation."""
        return self.forward_native(x)


43
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path: run the custom `silu_and_mul` op on a new buffer."""
        # Imported lazily so the module can be loaded without the compiled
        # custom ops being importable.
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # The op receives the output buffer as its first argument and the
        # buffer is returned afterwards, so the result is written in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: same as forward_cuda but using the IPEX op table."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

76

77
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Create the activation.

        Args:
            approximate: GELU variant, either "none" or "tanh".

        Raises:
            ValueError: if `approximate` is neither "none" nor "tanh".
        """
        super().__init__()
        # Reject unknown modes up front; the kernel dispatch in
        # forward_cuda/forward_xpu only knows these two values.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path: dispatch to the custom op matching `approximate`."""
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # The selected op is handed this buffer as its output argument.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: same dispatch as forward_cuda using the IPEX op table."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the `approximate` setting as an `approximate=...` string."""
        return f'approximate={repr(self.approximate)}'

125

126
class NewGELU(CustomOp):
    """Tanh-approximated GELU ("gelu_new").

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    elementwise; the output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path: run the custom `gelu_new` op on a new buffer."""
        from vllm import _custom_ops as ops

        # The op receives the output buffer as its first argument.
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the IPEX op is called with `x` only and returns the result."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_new(x)
145

146

147
class FastGELU(CustomOp):
    """Tanh-approximated GELU with a pre-computed constant ("gelu_fast").

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
    elementwise; the output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # 0.7978845608 is sqrt(2 / pi) rounded to 10 decimal places.
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path: run the custom `gelu_fast` op on a new buffer."""
        from vllm import _custom_ops as ops

        # The op receives the output buffer as its first argument.
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: the IPEX op is called with `x` only and returns the result."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_fast(x)
165

166

167
168
169
170
171
172
173
174
175
176
177
178
179
180
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).

    Elementwise; the output has the same shape as the input.
    """

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path: run the custom `gelu_quick` op on a new buffer."""
        from vllm import _custom_ops as ops

        # The op receives the output buffer as its first argument.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: same as forward_cuda but using the IPEX op table."""
        from vllm._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

191

192
193
194
195
196
197
198
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path; simply reuses the PyTorch-native implementation."""
        return self.forward_native(x)


205
206
207
208
209
210
211
212
213
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    The wrapped activation's output is divided elementwise by a learned
    `scales` vector. This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Create the scaled activation.

        Args:
            act_module: activation module whose output is scaled.
            intermediate_size: full (unsharded) length of the scale vector.
            input_is_parallel: if True, only this tensor-parallel rank's
                partition of the scales is allocated and loaded.
            params_dtype: dtype of `scales`; defaults to the current torch
                default dtype.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            # Each rank holds intermediate_size / tp_size scale entries.
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; values are expected to arrive through
        # weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by `scales`."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy `loaded_weight` into `param`, sharding it when parallel.

        When `input_is_parallel` is set, the slice of `loaded_weight` along
        dim 0 that belongs to this tensor-parallel rank is selected first.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

246

247
248
249
250
251
252
# Maps a lower-case activation name to a module instance. The instances are
# created once at import time, so every lookup of the same name yields the
# same object.
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "relu2": ReLUSquaredActivation(),
    "quick_gelu": QuickGELU(),
}


258
259
260
261
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: activation name; matched case-insensitively against
            the keys of `_ACTIVATION_REGISTRY`.
        quant_config: optional quantization config. If the (lower-cased)
            name is listed by `quant_config.get_scaled_act_names()`, the
            activation is wrapped in a `ScaledActivation`.
        intermediate_size: scale-vector length forwarded to
            `ScaledActivation`; required when the wrapper is used.
        input_is_parallel: forwarded to `ScaledActivation`.
        params_dtype: forwarded to `ScaledActivation`.

    Returns:
        The registered activation module, or a `ScaledActivation` wrapping
        it.

    Raises:
        ValueError: if the name is not registered, or if a scaled
            activation is requested without `intermediate_size`.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn