# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    FATReLU zeroes every element of the gate half x[:d] that is less than or
    equal to ``threshold`` and passes the remaining elements through
    unchanged (see the ``F.threshold`` call in ``forward_native``).

    Args:
        threshold: Cut-off applied to the gate half before the product.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel invoked by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No fused kernel is bound for CPU; use the pure-PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # F.threshold(input, threshold, value) replaces elements that are
        # <= threshold with ``value`` and leaves the others untouched.
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # The kernel writes its result into the preallocated ``out``.
        self.op(out, x, self.threshold)
        return out

    def extra_repr(self) -> str:
        # Surface the constructor argument in repr(), as GeluAndMul does.
        return f'threshold={self.threshold!r}'


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU gating: silu of the first half times the second half.

    With ``d = x.shape[-1] // 2`` this evaluates
    ``silu(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        has_c_kernel = (current_platform.is_cuda_alike()
                        or current_platform.is_cpu())
        if has_c_kernel:
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def _run_out_kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, value = x[..., :half], x[..., half:]
        return F.silu(gate) * value

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_out_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_out_kernel(x)

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten to 2-D, spell silu out as gate * sigmoid(gate), then
        # restore the leading dimensions.
        half = x.shape[-1] // 2
        flat = x.view(-1, x.shape[-1])
        gate = flat[:, :half]
        activated = gate * F.sigmoid(gate)
        return (activated * flat[:, half:]).view(*x.shape[:-1], half)

@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU with the gate on the second half.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
    The halves are swapped relative to SiluAndMul.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            # Fused kernel invoked by forward_cuda.
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu() or current_platform.is_cpu():
            # This class binds no fused kernel for XPU or CPU, so use the
            # pure-PyTorch path. In particular, ipex_ops.silu_and_mul must
            # not be bound here: it is the kernel SiluAndMul uses on XPU
            # (silu on the *first* half) and would compute a different
            # function than this op.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # The kernel writes its result into the preallocated ``out``.
        self.op(out, x)
        return out

    # TODO: add forward_xpu once an XPU kernel computing
    # x[:d] * silu(x[d:]) is available.


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU gating: GELU of the first half times the second half.

    With ``d = x.shape[-1] // 2`` this evaluates
    ``gelu(x[..., :d], approximate=approximate) * x[..., d:]``.

    Args:
        approximate: ``"none"`` for exact GELU or ``"tanh"`` for the tanh
            approximation; anything else raises ValueError.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        wants_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (torch.ops._C.gelu_tanh_and_mul
                       if wants_tanh else torch.ops._C.gelu_and_mul)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_tanh_and_mul
                       if wants_tanh else ipex_ops.gelu_and_mul)

    def _run_out_kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, value = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * value

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_out_kernel(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self._run_out_kernel(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU registered as ``gelu_new``.

    Evaluates 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
            return
        if current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        cubic_term = 0.044715 * torch.pow(x, 3.0)
        inner = scale * (x + cubic_term)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The C kernel fills a preallocated output of the same shape.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op returns its result directly.
        return self.op(x)

@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU registered as ``gelu_fast``.

    Evaluates 0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))).
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
            return
        if current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        polynomial = 1.0 + 0.044715 * x * x
        inner = x * 0.7978845608 * polynomial
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The C kernel fills a preallocated output of the same shape.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op returns its result directly.
        return self.op(x)

@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated GELU: x * sigmoid(1.702 * x)."""
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The kernel writes its result into the preallocated ``out``.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The bound IPEX op is called with the same (out, x) convention.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, relu(x) ** 2.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel: reuse the PyTorch implementation.
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ. The output is
    ``act_module(x) / scales``, where ``scales`` is a 1-D parameter of length
    ``intermediate_size`` (or this rank's shard of it when
    ``input_is_parallel`` is set) that is filled through ``weight_loader``.

    Args:
        act_module: The wrapped activation.
        intermediate_size: Full (unsharded) number of scale values.
        input_is_parallel: Whether the scales are sharded across
            tensor-parallel ranks.
        params_dtype: dtype of ``scales``; defaults to
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            # Each tensor-parallel rank keeps intermediate_size / tp_size
            # scales.
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Uninitialized storage; expected to be populated by weight_loader,
        # which is registered on the parameter below.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The 1-D scales broadcast against the last dimension of act(x).
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``, sharding it if needed.

        When ``input_is_parallel`` is set, only this rank's contiguous slice
        along dim 0 (of length ``param.shape[0]``) is taken.

        Raises:
            ValueError: if the (possibly sharded) weight does not have the
                same shape as ``param``.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, after which ``copy_`` could silently broadcast a
        # mismatched tensor into the parameter.
        if param_data.shape != loaded_weight.shape:
            raise ValueError(
                "Shape mismatch when loading activation scales: parameter "
                f"has shape {tuple(param_data.shape)}, loaded weight has "
                f"shape {tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)

# Factories for plain (non-gated) activations, keyed by lower-case name.
# Each value is a zero-argument callable producing a fresh module; LazyDict
# presumably invokes it on first lookup (see get_act_fn).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up a plain activation module by name.

    The name is lower-cased before lookup. Raises ValueError when the
    name is not in ``_ACTIVATION_REGISTRY``.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")


# Factories for gated "activation-and-mul" modules, keyed by lower-case name.
# "gelu" and "geglu" are aliases for the same module type.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu":
    lambda: GeluAndMul(),
    "silu":
    lambda: SiluAndMul(),
    "geglu":
    lambda: GeluAndMul(),
})


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up a gated activation-and-mul module (e.g. SiluAndMul) by name.

    The name is lower-cased before lookup. Raises ValueError when the
    name is not in ``_ACTIVATION_AND_MUL_REGISTRY``.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")