"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation, as used in openbmb/MiniCPM-S-1B-sft.

    With d = x.shape[-1] // 2, the last dimension of ``x`` is split into
    ``x[..., :d]`` and ``x[..., d:]``. The result is FATReLU(x[..., :d]) *
    x[..., d:]. FATReLU keeps entries strictly greater than ``threshold``
    and replaces every other entry with 0.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel, invoked as op(out, x, threshold) in forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound, so pin dispatch to the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Reference PyTorch implementation of the gated FATReLU."""
        half = x.shape[-1] // 2
        lhs, rhs = x[..., :half], x[..., half:]
        # F.threshold(t, th, 0.0) keeps t where t > th and writes 0.0 elsewhere.
        return F.threshold(lhs, self.threshold, 0.0) * rhs

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel into a freshly allocated output buffer."""
        half = x.shape[-1] // 2
        # new_empty inherits x's dtype and device.
        result = x.new_empty((*x.shape[:-1], half))
        self.op(result, x, self.threshold)
        return result


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU gating: silu(x[..., :d]) * x[..., d:], with d = x.shape[-1] // 2.

    The first half of the last dimension goes through SiLU and multiplies
    the second half.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gated = F.silu(x[..., :half])
        return gated * x[..., half:]

    def _launch_fused(self, x: torch.Tensor) -> torch.Tensor:
        # Shared by the CUDA and XPU paths: allocate the half-width output
        # (same dtype/device as x) and let the bound kernel fill it in place.
        half = x.shape[-1] // 2
        result = x.new_empty((*x.shape[:-1], half))
        self.op(result, x)
        return result

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused kernel path for CUDA-like devices."""
        return self._launch_fused(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Fused IPEX kernel path for XPU."""
        return self._launch_fused(x)


@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU with the gate on the second half.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
    The halves are swapped relative to SiluAndMul, which computes
    silu(x[:d]) * x[d:].

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            # No CPU kernel is bound, so pin dispatch to the PyTorch path.
            self._forward_method = self.forward_native
        # XPU deliberately binds no kernel. The IPEX kernel available here,
        # ipex_ops.silu_and_mul, is the one SiluAndMul uses for
        # silu(x[:d]) * x[d:]. That applies SiLU to the wrong half for this
        # op, so forward_xpu falls back to forward_native instead.

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused mul_and_silu kernel into a new output buffer."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path: no matching fused kernel, so use the native version."""
        return self.forward_native(x)


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU gating: GELU(x[..., :d]) * x[..., d:], with d = x.shape[-1] // 2.

    ``approximate`` selects the GELU variant. It must be "none" (exact) or
    "tanh"; any other value raises ValueError.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        # Only "none" and "tanh" survive the check above.
        use_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (torch.ops._C.gelu_tanh_and_mul
                       if use_tanh else torch.ops._C.gelu_and_mul)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_tanh_and_mul
                       if use_tanh else ipex_ops.gelu_and_mul)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        activated = F.gelu(x[..., :half], approximate=self.approximate)
        return activated * x[..., half:]

    def _launch_fused(self, x: torch.Tensor) -> torch.Tensor:
        # Shared by the CUDA and XPU paths: allocate the half-width output
        # (same dtype/device as x) and let the bound kernel fill it in place.
        half = x.shape[-1] // 2
        result = x.new_empty((*x.shape[:-1], half))
        self.op(result, x)
        return result

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused kernel path for CUDA-like devices."""
        return self._launch_fused(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Fused IPEX kernel path for XPU."""
        return self._launch_fused(x)

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU registered as "gelu_new".

    The native path computes, elementwise,
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3))).
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        coeff = math.sqrt(2.0 / math.pi)
        inner = coeff * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused kernel that writes into a buffer shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # Unlike the CUDA/CPU kernel, the XPU op is called out-of-place and
        # its return value is used directly.
        return self.op(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-based GELU approximation registered as "gelu_fast".

    The native path computes, elementwise,
    0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))),
    where 0.7978845608 is sqrt(2 / pi) rounded to 10 decimal places.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * poly))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused kernel that writes into a buffer shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The XPU op is called out-of-place; its return value is the output.
        return self.op(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Quick GELU approximation: x * sigmoid(1.702 * x), elementwise."""
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Fused kernel that writes into a buffer shaped like ``x``."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Fused IPEX kernel; same (out, x) calling convention as CUDA."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out


@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU, square(relu(x)), as introduced in
    https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel: the native expression is used on CUDA too.
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    The wrapped activation's output is divided elementwise by the
    ``scales`` parameter. This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            act_module: Activation applied before dividing by ``scales``.
            intermediate_size: Full (unsharded) length of the scale vector.
            input_is_parallel: If True, the scale vector is split across the
                tensor-parallel group. This rank's length is
                ``divide(intermediate_size, tp_size)``.
            params_dtype: dtype of ``scales``; defaults to
                ``torch.get_default_dtype()``.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Uninitialized storage. The values are presumably filled in by
        # weight_loader (registered below) when the checkpoint is loaded.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation, then divide elementwise by ``scales``."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        When ``input_is_parallel`` is set, rank ``r`` takes rows
        ``[r * n, (r + 1) * n)`` of ``loaded_weight``, where ``n`` is the
        local length of ``param``.

        Raises:
            ValueError: if the (possibly sliced) checkpoint tensor does not
                have exactly the shape of ``param``.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Explicit check instead of `assert`. Asserts are stripped under
        # `python -O`, and copy_ would then silently broadcast a mismatched
        # (e.g. length-1) tensor into the parameter.
        if param_data.shape != loaded_weight.shape:
            raise ValueError(
                f"Scale shape mismatch: parameter has shape "
                f"{tuple(param_data.shape)} but the loaded weight has shape "
                f"{tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)


# Name -> zero-argument factory for plain (non-gated) activation modules.
# Wrapped in LazyDict so a module is built only when its name is looked up
# (see get_act_fn).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name (case-insensitive).

    Args:
        act_fn_name: Key into _ACTIVATION_REGISTRY; lower-cased before lookup.

    Returns:
        The module the registry produces for that name.

    Raises:
        ValueError: if no activation is registered under the name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")


# Name -> zero-argument factory for gated "act(x[..., :d]) * x[..., d:]"
# modules, looked up through get_act_and_mul_fn.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu":
    lambda: GeluAndMul(),
    "silu":
    lambda: SiluAndMul(),
})


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul module (e.g. SiluAndMul) by name.

    Args:
        act_fn_name: Key into _ACTIVATION_AND_MUL_REGISTRY; lower-cased
            before lookup.

    Returns:
        The gated activation module the registry produces for that name.

    Raises:
        ValueError: if no activation-and-mul module is registered under
            the name.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]