activation.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Custom activation functions."""
3
import math
4
5
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
import torch
import torch.nn as nn
8
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
11
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
12
from vllm.model_executor.custom_op import CustomOp
13
from vllm.model_executor.utils import set_weight_attrs
14
from vllm.platforms import current_platform
15
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17


18
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated activation built on FATReLU (a thresholded ReLU).

    With ``d = x.shape[-1] // 2`` this returns
    ``FATReLU(x[..., :d]) * x[..., d:]``. Here FATReLU keeps an element only
    when it is strictly greater than ``threshold`` and replaces it with 0
    otherwise (the semantics of ``F.threshold`` with value 0.0).
    This is used in openbmb/MiniCPM-S-1B-sft.

    Args:
        threshold: cut-off below (or at) which gate values are zeroed.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Custom kernel from torch.ops._C; it writes into a caller-
            # allocated output tensor (see forward_cuda).
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No kernel is bound for CPU, so route this op to the
            # PyTorch reference implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel bound in ``__init__`` into a fresh output."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x, self.threshold)
        return result
52
53


54
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """Gated activation for SwiGLU.

    With ``d = x.shape[-1] // 2`` this returns ``silu(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def _call_op_into_output(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op(out, x)`` fill it."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._call_op_into_output(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op bound in __init__ uses the same (out, x) convention.
        return self._call_op_into_output(x)

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten leading dims to 2-D, spell SiLU out as g * sigmoid(g),
        # then restore the original leading dims. NOTE(review): the reason
        # for avoiding F.silu on Neuron is not shown here - confirm.
        half = x.shape[-1] // 2
        flat = x.view(-1, x.shape[-1])
        gate = flat[:, :half]
        activated = gate * F.sigmoid(gate)
        product = activated * flat[:, half:]
        return product.view(*x.shape[:-1], half)

99

100
101
102
103
104
105
106
107
108
109
110
111
112
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """Gated activation for SwiGLU with the SiLU on the second half.

    With ``d = x.shape[-1] // 2`` this returns ``x[..., :d] * silu(x[..., d:])``.
    That is the mirror image of ``silu(x[..., :d]) * x[..., d:]``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu() or current_platform.is_cpu():
            # No matching kernel is available here. In particular, IPEX's
            # ``silu_and_mul`` computes silu(x[:d]) * x[d:], the opposite
            # operand order, so binding it would silently give wrong
            # results. Route these platforms to the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the kernel bound in ``__init__`` into a fresh output."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # NOTE: there is no dedicated XPU kernel for this op; XPU uses
    # forward_native (selected in __init__).


137
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """Gated activation for GeGLU.

    With ``d = x.shape[-1] // 2`` this returns ``GELU(x[..., :d]) * x[..., d:]``.

    Args:
        approximate: GELU variant, passed to ``F.gelu`` in the native path.
            Either ``"none"`` (erf-based) or ``"tanh"``.

    Raises:
        ValueError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        # After validation exactly one of the two variants applies.
        use_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (torch.ops._C.gelu_tanh_and_mul
                       if use_tanh else torch.ops._C.gelu_and_mul)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_tanh_and_mul
                       if use_tanh else ipex_ops.gelu_and_mul)

    def _call_op_into_output(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op(out, x)`` fill it."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._call_op_into_output(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX ops bound in __init__ share the (out, x) convention.
        return self._call_op_into_output(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

187

188
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU (``gelu_new``), applied elementwise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The output has the same shape as ``x``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        coeff = math.sqrt(2.0 / math.pi)
        inner = coeff * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The _C kernel writes into a caller-provided tensor like ``x``.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op takes only ``x`` and its return value is the result.
        return self.op(x)
212

213

214
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU (``gelu_fast``), applied elementwise.

    Computes ``0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x*x)))``,
    where 0.7978845608 is a hard-coded approximation of sqrt(2 / pi).
    The output has the same shape as ``x``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        inner = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The _C kernel writes into a caller-provided tensor like ``x``.
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The IPEX op takes only ``x`` and its return value is the result.
        return self.op(x)
237

238

239
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated GELU, ``x * sigmoid(1.702 * x)``, elementwise.

    The output has the same shape as ``x``.
    """
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def _call_op_into_output(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate an output like ``x`` and let ``self.op(out, x)`` fill it."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self._call_op_into_output(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # Unlike gelu_new/gelu_fast, the IPEX gelu_quick op is called with
        # the same (out, x) convention as the _C kernel.
        return self._call_op_into_output(x)

267

268
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, ``relu(x) ** 2``, applied elementwise.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No custom kernel: reuse the PyTorch implementation.
        return self.forward_native(x)


282
283
284
285
286
287
288
289
290
class ScaledActivation(nn.Module):
    """Apply an activation, then divide its output by learned scales.

    This is used for some quantization methods like AWQ. ``scales`` is a
    1-D parameter, so the division broadcasts along the last dimension of
    the activation output.

    Args:
        act_module: activation module applied before scaling.
        intermediate_size: unsharded length of ``scales``.
        input_is_parallel: if True, ``scales`` holds only this tensor-parallel
            rank's shard, i.e. ``intermediate_size / tp_size`` entries.
        params_dtype: dtype of ``scales``; ``None`` means
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        num_scales = (divide(intermediate_size,
                             get_tensor_model_parallel_world_size())
                      if input_is_parallel else intermediate_size)
        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)
        # Uninitialized storage; real values arrive via weight_loader.
        self.scales = nn.Parameter(torch.empty(num_scales, dtype=dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.act(x)
        return activated / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``.

        When the input is parallel, the checkpoint tensor is narrowed along
        dim 0 to the contiguous shard owned by this rank.
        """
        target = param.data
        if self.input_is_parallel:
            shard_len = target.shape[0]
            offset = get_tensor_model_parallel_rank() * shard_len
            loaded_weight = loaded_weight.narrow(0, offset, shard_len)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

323

324
325
326
327
328
329
330
331
332
333
334
335
336
# Single-input activations, keyed by lower-case name. Each value is a
# zero-argument factory; LazyDict presumably invokes it on first lookup
# (see vllm.utils.LazyDict - confirm caching semantics there).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    # torch's tanh-approximated GELU.
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
342
343


344
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Return the activation registered in _ACTIVATION_REGISTRY by name.

    The name is lower-cased before lookup, so matching is case-insensitive.

    Raises:
        ValueError: if no activation is registered under that name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")
352
353
354
355
356


# Gated "activation-and-mul" modules, keyed by lower-case name. "gelu" and
# "geglu" both map to GeluAndMul() with its default approximate="none".
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
})


361
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Return a gated activation-and-mul module (e.g. SiluAndMul) by name.

    The name is lower-cased before lookup in _ACTIVATION_AND_MUL_REGISTRY,
    so matching is case-insensitive.

    Raises:
        ValueError: if no gated activation is registered under that name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")