activation.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Custom activation functions."""
3
import math
4
5
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
import torch
import torch.nn as nn
8
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
11
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
12
from vllm.model_executor.custom_op import CustomOp
13
from vllm.model_executor.utils import set_weight_attrs
14
from vllm.platforms import current_platform
15
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
16
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18


19
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Entries of the gating half x[:d] that are <= threshold become 0.
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound for this op, so pin dispatch to the
            # PyTorch-native implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # F.threshold(v, t, 0.0) keeps v where v > t and writes 0.0 elsewhere.
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel into a freshly allocated output tensor."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # The kernel fills `out` in place (out, input, threshold).
        self.op(out, x, self.threshold)
        return out

    def extra_repr(self) -> str:
        return f'threshold={self.threshold!r}'
53
54


55
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
            # Optional optimized kernel used by forward_cuda when
            # envs.VLLM_USE_OPT_OP is set. Looked up with a default so a
            # build that does not ship this kernel can still construct the
            # module; forward_cuda then falls back to `self.op`.
            self.op_opt = getattr(torch.ops._C, "silu_and_mul_opt", None)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel into a freshly allocated output tensor."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # Prefer the optimized kernel only when requested AND available.
        if envs.VLLM_USE_OPT_OP and self.op_opt is not None:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel into a freshly allocated output tensor."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        """Same math as forward_native, with silu(v) written as v*sigmoid(v).

        The input is flattened to 2-D with `view`, so it must be viewable
        as (-1, x.shape[-1]); the result is viewed back to x.shape[:-1] + (d,).
        """
        d = x.shape[-1] // 2
        x_reshaped = x.view(-1, x.shape[-1])
        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
        result = s * x_reshaped[:, d:]
        return result.view(*x.shape[:-1], d)

104

105
106
107
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu() or current_platform.is_xpu():
            # No matching kernel is bound on these platforms, so pin
            # dispatch to the PyTorch-native implementation.
            # ipex_ops.silu_and_mul is deliberately NOT used on XPU:
            # SiluAndMul binds it for silu(x[:d]) * x[d:], which gates the
            # opposite half from this op's x[:d] * silu(x[d:]).
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel into a freshly allocated output tensor."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO: bind an XPU kernel that computes x[:d] * silu(x[d:]) and add a
    # forward_xpu that uses it.

141

142
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """
        Args:
            approximate: GELU variant, "none" (exact) or "tanh".

        Raises:
            ValueError: If `approximate` is not "none" or "tanh".
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # `op_opt` is an optional optimized kernel used by forward_cuda
            # when envs.VLLM_USE_OPT_OP is set. It is looked up with a
            # default so a build that does not ship it can still construct
            # the module; forward_cuda then falls back to `self.op`.
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = getattr(torch.ops._C, "gelu_and_mul_opt",
                                      None)
            elif approximate == "tanh":
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = getattr(torch.ops._C,
                                      "gelu_tanh_and_mul_opt", None)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            if approximate == "none":
                self.op = ipex_ops.gelu_and_mul
            else:
                self.op = ipex_ops.gelu_tanh_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel into a freshly allocated output tensor."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # Prefer the optimized kernel only when requested AND available.
        if envs.VLLM_USE_OPT_OP and self.op_opt is not None:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel into a freshly allocated output tensor."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'

197

198
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU, registered as "gelu_new".

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    element-wise; the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which fills a new tensor shaped like x."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, which returns its result directly."""
        return self.op(x)
222

223

224
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-approximated GELU, registered as "gelu_fast".

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))
    element-wise, where 0.7978845608 is sqrt(2 / pi) rounded to 10 decimal
    places; the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which fills a new tensor shaped like x."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, which returns its result directly."""
        return self.op(x)
247

248

249
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Quick GELU approximation: x -> x * sigmoid(1.702 * x), element-wise.

    Reference:
    https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the compiled kernel, which fills a new tensor shaped like x."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, which fills a new tensor shaped like x."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

277

278
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2

    Computes x -> relu(x) ** 2 element-wise.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel exists for this op; reuse the native path.
        return self.forward_native(x)


292
293
294
295
296
297
298
299
300
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ. The output of the
    wrapped activation is divided by the 1-D ``scales`` parameter, which
    broadcasts over the last dimension.

    Args:
        act_module: Activation module applied before scaling.
        intermediate_size: Full (unpartitioned) number of channels.
        input_is_parallel: If True, ``scales`` holds only this
            tensor-parallel rank's ``intermediate_size / tp_size`` slice and
            ``weight_loader`` narrows checkpoint tensors to that slice.
        params_dtype: dtype of ``scales``; defaults to
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; the real values arrive via weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation, then divide by the per-channel scales."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param``.

        When ``input_is_parallel`` is True, only this rank's contiguous slice
        along dim 0 of ``loaded_weight`` is copied.

        Raises:
            AssertionError: If the (possibly narrowed) tensor's shape does
                not match ``param``.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape, (
            f"scales shape mismatch: parameter {tuple(param_data.shape)} "
            f"vs loaded {tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)

333

334
335
336
337
338
339
340
341
342
343
344
345
346
# Plain (non-gated) activations keyed by lowercase name; get_act_fn
# lowercases its argument before looking it up here. Each value is a
# zero-argument factory. LazyDict (vllm.utils) is expected to call a factory
# only when its key is first looked up -- NOTE(review): confirm there
# whether the built module is cached and shared across lookups.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
352
353


354
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: Registry key, matched case-insensitively (it is
            lowercased before lookup).

    Returns:
        The module that ``_ACTIVATION_REGISTRY`` yields for that name.

    Raises:
        ValueError: If the lowercased name is not registered.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]
362
363
364
365
366


# Gated "activation-and-mul" modules keyed by lowercase name; each value is a
# zero-argument factory. "gelu" and "gelu_and_mul" are aliases: both build a
# GeluAndMul with its default approximate="none" (exact GELU).
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "gelu_and_mul": lambda: GeluAndMul(),
})


371
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (e.g. SiluAndMul) function by name.

    Args:
        act_fn_name: Registry key, matched case-insensitively (it is
            lowercased before lookup).

    Returns:
        The module that ``_ACTIVATION_AND_MUL_REGISTRY`` yields for that name.

    Raises:
        ValueError: If the lowercased name is not registered.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]