activation.py 9.54 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.model_executor.layers.quantization import QuantizationConfig
13
from vllm.model_executor.utils import set_weight_attrs
14
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16


17
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension of the input is split in half and the function
    computes x -> FATReLU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
    FATReLU keeps the elements that are strictly greater than
    ``threshold`` and replaces all others with zero.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Cut-off below (or at) which the gate half is zeroed.
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No custom kernel is invoked here; the native path is reused.
        return self.forward_native(x)


45
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """Gated SiLU activation, as used by SwiGLU.

    The last dimension of the input is split in half and the function
    computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``silu_and_mul`` op from ``vllm._custom_ops``."""
        from vllm import _custom_ops as ops

        half = x.shape[-1] // 2
        # The op receives a preallocated buffer with the halved last dim.
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        ops.silu_and_mul(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``silu_and_mul`` op from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        ops.silu_and_mul(result, x)
        return result

79

80
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """Gated GELU activation, as used by GeGLU.

    The last dimension of the input is split in half and the function
    computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
    ``approximate`` selects the GELU variant and must be either
    "none" or "tanh"; any other value raises ``ValueError``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        supported = ("none", "tanh")
        if approximate not in supported:
            raise ValueError(f"Unknown approximate mode: {approximate}")

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the matching GELU-and-mul op from ``vllm._custom_ops``."""
        from vllm import _custom_ops as ops

        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        # Pick the op that corresponds to the configured GELU variant.
        mode = self.approximate
        if mode == "none":
            ops.gelu_and_mul(result, x)
        elif mode == "tanh":
            ops.gelu_tanh_and_mul(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the matching GELU-and-mul op from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        mode = self.approximate
        if mode == "none":
            ops.gelu_and_mul(result, x)
        elif mode == "tanh":
            ops.gelu_tanh_and_mul(result, x)
        return result

    def extra_repr(self) -> str:
        """Show the configured GELU variant in the module repr."""
        return f"approximate={self.approximate!r}"

129

130
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        inner = scale * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_new`` op from ``vllm._custom_ops``."""
        from vllm import _custom_ops as ops

        # The op is handed a preallocated buffer shaped like the input.
        result = torch.empty_like(x)
        ops.gelu_new(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the ``gelu_new`` op from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_new(x)
150

151

152
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU using the literal constant 0.7978845608."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_fast`` op from ``vllm._custom_ops``."""
        from vllm import _custom_ops as ops

        # The op is handed a preallocated buffer shaped like the input.
        result = torch.empty_like(x)
        ops.gelu_fast(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the ``gelu_fast`` op from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_fast(x)
171

172

173
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_quick`` op from ``vllm._custom_ops``."""
        from vllm import _custom_ops as ops

        # The op is handed a preallocated buffer shaped like the input.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_quick`` op from ``vllm._ipex_ops``."""
        from vllm._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

197

198
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU (relu^2), the activation introduced in
    https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No custom kernel is invoked here; the native path is reused.
        return self.forward_native(x)


212
213
214
215
216
217
218
219
220
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    The wrapped activation is applied first and its output is then divided
    element-wise by the ``scales`` parameter.
    This is used for some quantization methods like AWQ.

    Args:
        act_module: The activation module to wrap.
        intermediate_size: Full (unsharded) length of the scale vector.
        input_is_parallel: When True, the scale vector is split across the
            tensor-parallel group and each rank holds one shard of it.
        params_dtype: Dtype of ``scales``; defaults to
            ``torch.get_default_dtype()`` when None.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized: the values are expected to be filled in
        # through ``weight_loader`` before the module is used.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by ``scales``."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's shard of it) into ``param``.

        Args:
            param: Destination parameter; its first dimension defines the
                shard size when ``input_is_parallel`` is True.
            loaded_weight: Source tensor. With ``input_is_parallel`` it is
                narrowed along dim 0 to the slice owned by this rank.

        Raises:
            AssertionError: If the (possibly narrowed) source shape differs
                from the parameter shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; otherwise a broadcastable but wrong-shaped tensor
        # could be copied silently. The exception type is kept unchanged.
        if param_data.shape != loaded_weight.shape:
            raise AssertionError(
                "ScaledActivation scales shape mismatch: parameter has shape "
                f"{tuple(param_data.shape)}, but loaded weight has shape "
                f"{tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)

253

254
255
256
257
258
259
260
261
262
263
264
265
266
# Activation functions that can be requested by name through get_act_fn().
# Every value is a zero-argument factory returning the module. LazyDict comes
# from vllm.utils and is presumably what defers calling the factory until the
# key is looked up -- confirm against its implementation.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
272
273


274
275
276
277
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    The name is matched case-insensitively against the activation registry.
    When ``quant_config`` lists the activation among its scaled activation
    names, the module is wrapped in ``ScaledActivation``.

    Args:
        act_fn_name: Registry name of the activation.
        quant_config: Optional quantization config consulted for scaled
            activation names.
        intermediate_size: Size forwarded to ``ScaledActivation``; required
            whenever the activation is scaled.
        input_is_parallel: Forwarded to ``ScaledActivation``.
        params_dtype: Forwarded to ``ScaledActivation``.

    Raises:
        ValueError: If the name is not registered, or if a scaled activation
            is requested without ``intermediate_size``.
    """
    key = act_fn_name.lower()
    if key not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {key!r} is not supported.")
    activation = _ACTIVATION_REGISTRY[key]

    needs_scaling = (quant_config is not None
                     and key in quant_config.get_scaled_act_names())
    if not needs_scaling:
        return activation

    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    return ScaledActivation(activation, intermediate_size, input_is_parallel,
                            params_dtype)