activation.py 9.48 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.model_executor.layers.quantization import QuantizationConfig
13
from vllm.model_executor.utils import set_weight_attrs
zhuwenwen's avatar
zhuwenwen committed
14
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16


17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension of the input is split into two halves of width
    d = x.shape[-1] // 2; the result is FATReLU(x[..., :d]) * x[..., d:].
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Cut-off handed to F.threshold: gate values that are not strictly
        # greater than it are replaced by 0.0.
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the thresholded gate to the first half and scale the second."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the PyTorch-native implementation."""
        return self.forward_native(x)


44
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, choosing the variant via VLLM_USE_OPT_OP."""
        # Imported lazily so the module can be loaded without the compiled ops.
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            ops.silu_and_mul_opt(out, x)
        else:
            ops.silu_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel; same output layout as forward_cuda()."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

80

81
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Create the activation.

        Args:
            approximate: GELU variant, either "none" or "tanh".

        Raises:
            ValueError: If `approximate` is not one of the supported modes.
        """
        super().__init__()
        # Validate before storing so an invalid mode is never kept on self.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel matching `approximate` and VLLM_USE_OPT_OP."""
        # Imported lazily so the module can be loaded without the compiled ops.
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # __init__ only admits "none" and "tanh", so one branch always runs.
        if self.approximate == "none":
            if envs.VLLM_USE_OPT_OP:
                ops.gelu_and_mul_opt(out, x)
            else:
                ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            if envs.VLLM_USE_OPT_OP:
                ops.gelu_tanh_and_mul_opt(out, x)
            else:
                ops.gelu_tanh_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel matching `approximate`."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the `approximate` setting for the module's repr."""
        return f'approximate={repr(self.approximate)}'

135

136
class NewGELU(CustomOp):
    """Tanh-based GELU approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    element-wise; the output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom `gelu_new` kernel into a freshly allocated tensor."""
        # Imported lazily so the module can be loaded without the compiled ops.
        from vllm import _custom_ops as ops

        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX `gelu_new` op."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_new(x)
155

156

157
class FastGELU(CustomOp):
    """Tanh-based GELU approximation with a precomputed constant.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
    element-wise, where 0.7978845608 is sqrt(2 / pi) rounded to ten
    decimal places. The output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom `gelu_fast` kernel into a freshly allocated tensor."""
        # Imported lazily so the module can be loaded without the compiled ops.
        from vllm import _custom_ops as ops

        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX `gelu_fast` op."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_fast(x)
175

176

177
178
179
180
181
182
183
184
185
186
187
188
189
190
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).

    Applied element-wise; the output has the same shape as the input.
    """

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom `gelu_quick` kernel into a freshly allocated tensor."""
        # Imported lazily so the module can be loaded without the compiled ops.
        from vllm import _custom_ops as ops

        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX `gelu_quick` op into a freshly allocated tensor."""
        from vllm._ipex_ops import ipex_ops as ops

        # Uninitialised buffer; the kernel is expected to fill it in place.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

201

202
203
204
205
206
207
208
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2

    The input is passed through ReLU and then squared element-wise, so the
    output has the same shape as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the PyTorch-native implementation."""
        return self.forward_native(x)


215
216
217
218
219
220
221
222
223
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    The wrapped activation's output is divided element-wise by a learned
    `scales` vector. This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Create the wrapper and allocate the (uninitialised) scales.

        Args:
            act_module: Activation applied before the scaling.
            intermediate_size: Full (unsharded) length of the scale vector.
            input_is_parallel: When True, only this tensor-parallel rank's
                slice of the scales is allocated and loaded.
            params_dtype: Dtype of the scales; defaults to
                torch.get_default_dtype() when None.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated with torch.empty: the values are meaningless until
        # weight_loader() copies a checkpoint tensor into them.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return act(x) divided element-wise by the scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy `loaded_weight` (or this rank's shard of it) into `param`.

        Args:
            param: Destination parameter; its first dimension defines the
                shard size.
            loaded_weight: Source tensor. With `input_is_parallel` it is
                narrowed along dim 0 to the slice starting at
                tp_rank * shard_size.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Shapes must agree exactly before the in-place copy.
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

256

257
258
259
260
261
262
# Maps lower-case activation names to module instances. The instances are
# created once, when this module is imported, and are shared by every lookup.
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
    "relu2": ReLUSquaredActivation(),
    "quick_gelu": QuickGELU(),
}


268
269
270
271
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: Registry key; matched case-insensitively.
        quant_config: Optional quantization config. If it lists the
            activation in get_scaled_act_names(), the activation is wrapped
            in a ScaledActivation.
        intermediate_size: Size of the scale vector; required only when the
            activation is wrapped.
        input_is_parallel: Forwarded to ScaledActivation.
        params_dtype: Forwarded to ScaledActivation.

    Returns:
        The module stored in _ACTIVATION_REGISTRY (the stored instance, not
        a copy), or a new ScaledActivation wrapping it.

    Raises:
        ValueError: If the name is not registered, or if scaling is
            requested without `intermediate_size`.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn