activation.py 9.86 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.model_executor.layers.quantization import QuantizationConfig
13
from vllm.model_executor.utils import set_weight_attrs
14
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
15
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17


18
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    FATReLU is realised with ``F.threshold(gate, threshold, 0.0)``: gate
    elements that are not strictly greater than ``threshold`` become 0 and
    the remaining elements pass through unchanged.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Gate values <= threshold are zeroed by F.threshold.
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated CUDA kernel is used; reuse the native implementation.
        return self.forward_native(x)


46
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    @staticmethod
    def _empty_output(x: torch.Tensor) -> torch.Tensor:
        """Allocate an uninitialised buffer shaped like ``x`` with its last
        dimension halved; the device kernels write their result into it."""
        half = x.shape[-1] // 2
        return torch.empty((*x.shape[:-1], half),
                           dtype=x.dtype,
                           device=x.device)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        out = self._empty_output(x)
        # VLLM_USE_OPT_OP switches to the ``*_opt`` variant of the kernel.
        kernel = (ops.silu_and_mul_opt
                  if envs.VLLM_USE_OPT_OP else ops.silu_and_mul)
        kernel(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        out = self._empty_output(x)
        ops.silu_and_mul(out, x)
        return out

83

84
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Args:
        approximate: ``"none"`` for exact GELU or ``"tanh"`` for the tanh
            approximation, with the same meaning as in
            ``torch.nn.functional.gelu``. Any other value raises ValueError.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in {"none", "tanh"}:
            raise ValueError(f"Unknown approximate mode: {approximate}")

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        activated = F.gelu(x[..., :half], approximate=self.approximate)
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        # VLLM_USE_OPT_OP switches to the ``*_opt`` variant of each kernel.
        if self.approximate == "none":
            kernel = (ops.gelu_and_mul_opt
                      if envs.VLLM_USE_OPT_OP else ops.gelu_and_mul)
            kernel(out, x)
        elif self.approximate == "tanh":
            kernel = (ops.gelu_tanh_and_mul_opt
                      if envs.VLLM_USE_OPT_OP else ops.gelu_tanh_and_mul)
            kernel(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        if self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        elif self.approximate == "none":
            ops.gelu_and_mul(out, x)
        return out

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

139

140
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        cubic = x + 0.044715 * torch.pow(x, 3.0)
        return 0.5 * x * (1.0 + torch.tanh(scale * cubic))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        # The custom kernel fills a caller-provided output buffer.
        result = torch.empty_like(x)
        ops.gelu_new(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        # This binding is called with x only and its return value is used.
        return ops.gelu_new(x)
160

161

162
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU.

    Computes 0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))),
    where 0.7978845608 ~= sqrt(2 / pi).
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        tanh_arg = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(tanh_arg))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        # The custom kernel fills a caller-provided output buffer.
        result = torch.empty_like(x)
        ops.gelu_fast(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        # This binding is called with x only and its return value is used.
        return ops.gelu_fast(x)
181

182

183
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x).

    Reference (HF transformers QuickGELUActivation):
    https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        # The custom kernel fills a caller-provided output buffer.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        # The IPEX binding also writes into a caller-provided buffer.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

207

208
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, relu(x) ** 2.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return rectified.square()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated CUDA kernel is used; reuse the native implementation.
        return self.forward_native(x)

221

222
223
224
225
226
227
228
229
230
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    The wrapped activation's output is divided element-wise by the
    ``scales`` parameter. ``scales`` has length ``intermediate_size``, or
    ``intermediate_size / tp_size`` when ``input_is_parallel`` is True.
    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            # Each tensor-parallel rank holds only its slice of the scales.
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialised; weight_loader fills it from the checkpoint.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by ``scales``."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy checkpoint scales into ``param``.

        When ``input_is_parallel`` is True, this rank's contiguous slice of
        length ``param.shape[0]`` is taken along dim 0 of ``loaded_weight``
        (presumably the full, unsharded tensor) before copying.

        Raises:
            ValueError: if the (possibly sliced) weight's shape differs from
                ``param``'s shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and Tensor.copy_ would then silently broadcast a
        # compatible-but-wrong shape (e.g. (1,)) into the parameter.
        if param_data.shape != loaded_weight.shape:
            raise ValueError(
                f"Scale shape mismatch: parameter has shape "
                f"{tuple(param_data.shape)} but loaded weight has shape "
                f"{tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)

263

264
265
266
267
268
269
270
271
272
273
274
275
276
# Lowercase activation name -> zero-argument factory for the module.
# get_act_fn looks names up here; LazyDict presumably builds each entry on
# first access (TODO confirm caching semantics in vllm.utils.LazyDict).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
282
283


284
285
286
287
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    The name is matched case-insensitively against _ACTIVATION_REGISTRY.
    If ``quant_config`` lists the name in ``get_scaled_act_names()``, the
    activation is wrapped in a ScaledActivation, which requires
    ``intermediate_size``.

    Raises:
        ValueError: if the name is unknown, or if a scaled activation is
            requested without ``intermediate_size``.
    """
    name = act_fn_name.lower()
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[name]

    needs_scaling = (quant_config is not None
                     and name in quant_config.get_scaled_act_names())
    if not needs_scaling:
        return act_fn

    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                            params_dtype)