"vllm/entrypoints/openai/completion/serving.py" did not exist on "4d01b6428448225807e6605d04e37e29fe729b44"
activation.py 12.2 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.model_executor.utils import set_weight_attrs
13
from vllm.platforms import current_platform
14
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
15
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17


18
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation: ``FATReLU(x[..., :d]) * x[..., d:]``.

    The last dimension is split in half (``d = x.shape[-1] // 2``). The first
    half goes through a thresholded ReLU: every element that is not strictly
    greater than ``threshold`` becomes 0. The result is then multiplied
    element-wise with the second half. This is used in
    openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        on_cuda_like = current_platform.is_cuda_alike()
        if on_cuda_like:
            # Fused kernel; it writes into a caller-allocated output tensor.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound, so always take the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Reference PyTorch implementation of the gated FATReLU."""
        half = x.shape[-1] // 2
        lhs, rhs = x[..., :half], x[..., half:]
        return F.threshold(lhs, self.threshold, 0.0) * rhs

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel into a freshly allocated half-width tensor."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x, self.threshold)
        return result
52
53


54
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU gate: ``silu(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.

    On CUDA-alike and CPU platforms, ``forward_cuda`` uses the
    ``silu_and_mul_opt`` kernel whenever ``envs.VLLM_USE_OPT_OP`` is truthy.
    Otherwise it uses the regular ``silu_and_mul`` kernel.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        platform = current_platform
        if platform.is_cuda_alike() or platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
            # NOTE(review): this lookup happens even when VLLM_USE_OPT_OP is
            # off. It presumably fails if the build does not register
            # silu_and_mul_opt for this platform; confirm for CPU builds.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        activated = F.silu(x[..., :half])
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the regular or the ``_opt`` fused kernel.

        The environment flag is read on every call.
        """
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        kernel(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel into a freshly allocated half-width tensor."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

96

97
98
99
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """Reversed SwiGLU gate: ``x[..., :d] * silu(x[..., d:])``.

    Here ``d = x.shape[-1] // 2``. Unlike ``SiluAndMul``, the SiLU is applied
    to the *second* half of the last dimension and multiplied by the first.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            # No CPU kernel is bound, so always take the PyTorch path.
            self._forward_method = self.forward_native
        # XPU: no kernel is bound on purpose. The only IPEX candidate,
        # ipex_ops.silu_and_mul, computes silu(x[:d]) * x[d:], which is the
        # opposite half-ordering from this op. Binding it as self.op (as this
        # class previously did) would give wrong results as soon as a
        # forward_xpu used it. Without forward_xpu, XPU takes CustomOp's
        # default XPU path (presumably forward_native; verify in CustomOp).

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused mul_and_silu kernel into a half-width output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu (needs a kernel with the
    # x[:d] * silu(x[d:]) ordering; ipex_ops.silu_and_mul is NOT that).

133

134
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU gate: ``GELU(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.

    Args:
        approximate: GELU variant, ``"none"`` or ``"tanh"``. It is passed to
            ``F.gelu`` in the native path and selects the matching fused
            kernel on the other paths.

    Raises:
        ValueError: if ``approximate`` is not one of the supported modes.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        exact = approximate == "none"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            base_name = "gelu_and_mul" if exact else "gelu_tanh_and_mul"
            # Resolve both the regular and the "_opt" kernel now;
            # forward_cuda picks one of them on every call.
            self.op = getattr(torch.ops._C, base_name)
            self.op_opt = getattr(torch.ops._C, f"{base_name}_opt")
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_and_mul
                       if exact else ipex_ops.gelu_tanh_and_mul)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gated = F.gelu(x[..., :half], approximate=self.approximate)
        return gated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the regular or ``_opt`` kernel, chosen by VLLM_USE_OPT_OP."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        kernel(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel into a freshly allocated half-width tensor."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

189

190
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """

    def __init__(self):
        super().__init__()
        platform = current_platform
        if platform.is_cuda_alike() or platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        coef = math.sqrt(2.0 / math.pi)
        inner = coef * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``_C`` kernel, which writes into a same-shaped buffer."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op returns its result instead of filling an out-buffer.
        return self.op(x)
214

215

216
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-approximated GELU with ``sqrt(2/pi)`` pre-evaluated.

    Computes ``0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x*x)))``.
    """

    def __init__(self):
        super().__init__()
        platform = current_platform
        if platform.is_cuda_alike() or platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # Same left-to-right evaluation order as the reference expression.
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``_C`` kernel, which writes into a same-shaped buffer."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op returns its result instead of filling an out-buffer.
        return self.op(x)
239

240

241
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated GELU: ``x * sigmoid(1.702 * x)``.

    Reference:
    https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``_C`` kernel into a same-shaped output buffer."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel, invoked as ``op(out, x)`` like the _C kernel."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

269

270
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, ``relu(x) ** 2``.

    Introduced in https://arxiv.org/abs/2109.08668v2. There is no fused
    kernel; every path uses the PyTorch implementation.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the native path; no dedicated kernel exists.
        return self.forward_native(x)


284
285
286
287
288
289
290
291
292
class ScaledActivation(nn.Module):
    """Wrap an activation and divide its output by learned ``scales``.

    This is used for some quantization methods like AWQ.

    Args:
        act_module: activation applied before the division.
        intermediate_size: total number of scale entries before sharding.
        input_is_parallel: if True, ``intermediate_size`` is divided by the
            tensor-parallel world size (via ``divide``). This rank then holds
            only its slice of the scales.
        params_dtype: dtype of ``scales``. If None,
            ``torch.get_default_dtype()`` is used.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel

        local_size = intermediate_size
        if input_is_parallel:
            world = get_tensor_model_parallel_world_size()
            local_size = divide(intermediate_size, world)

        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)
        self.scales = nn.Parameter(torch.empty(local_size, dtype=dtype))
        # The checkpoint loader calls this hook to fill ``scales``.
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``.

        When the input is parallel, only this rank's contiguous slice along
        dim 0 is copied.
        """
        target = param.data
        if self.input_is_parallel:
            rank = get_tensor_model_parallel_rank()
            rows = target.shape[0]
            loaded_weight = loaded_weight.narrow(0, rank * rows, rows)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

325

326
327
328
329
330
331
332
333
334
335
336
337
338
# Plain activations, keyed by the lower-case names accepted by get_act_fn.
# Each value is a zero-argument factory. Wrapping them in LazyDict presumably
# defers building a module until its key is looked up (see vllm.utils.LazyDict).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
344
345


346
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation module by name.

    The name is lower-cased before it is looked up in
    ``_ACTIVATION_REGISTRY``.

    Raises:
        ValueError: if the lower-cased name is not registered.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")
354
355
356
357
358
359
360
361


# Fused "activation(first half) * second half" modules, keyed by the
# lower-case names accepted by get_act_and_mul_fn.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu":
    lambda: GeluAndMul(),
    "silu":
    lambda: SiluAndMul(),
})


362
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get a fused activation-and-multiply module (e.g. SiluAndMul) by name.

    The name is lower-cased before it is looked up in
    ``_ACTIVATION_AND_MUL_REGISTRY``.

    Raises:
        ValueError: if the lower-cased name is not registered.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")