activation.py 11.7 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.model_executor.utils import set_weight_attrs
13
from vllm.platforms import current_platform
14
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16


17
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension of the input is split in half and the op returns
    FATReLU(x[..., :d]) * x[..., d:] with d = x.shape[-1] // 2.  FATReLU
    keeps elements strictly greater than ``threshold`` and replaces every
    other element with 0.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        # The fused kernel is only bound on CUDA-like and CPU platforms;
        # on any other platform ``self.op`` is left unset.
        has_fused_kernel = (current_platform.is_cuda_alike()
                            or current_platform.is_cpu())
        if has_fused_kernel:
            self.op = torch.ops._C.fatrelu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the activation with plain PyTorch ops."""
        half = x.shape[-1] // 2
        first_half, second_half = x[..., :half], x[..., half:]
        # F.threshold(v, t, 0.0) yields v where v > t and 0.0 elsewhere.
        return F.threshold(first_half, self.threshold, 0.0) * second_half

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the activation with the fused kernel bound in __init__."""
        half = x.shape[-1] // 2
        # The kernel writes into a caller-provided buffer whose last
        # dimension is half that of the input.
        result = torch.empty(x.shape[:-1] + (half, ),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x, self.threshold)
        return result
49
50


51
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """Gated SiLU activation, as used by SwiGLU.

    Computes x -> silu(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        # Bind the platform-specific fused kernel; other platforms leave
        # ``self.op`` unset.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def _run_fused(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it."""
        half = x.shape[-1] // 2
        result = torch.empty(x.shape[:-1] + (half, ),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        activated = F.silu(x[..., :half])
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in __init__."""
        return self._run_fused(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in __init__ (same call shape as CUDA)."""
        return self._run_fused(x)

89

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU with the SiLU on the second half.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
    Note the operand order: this is the mirror image of ``silu_and_mul``,
    which applies SiLU to the first half instead.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.mul_and_silu
        # No kernel is bound on XPU on purpose.  The only related kernel
        # referenced here, ``ipex_ops.silu_and_mul``, computes
        # silu(x[:d]) * x[d:], i.e. the operands are swapped relative to
        # this class.  Binding it as ``self.op`` would make a future
        # ``forward_xpu`` that mirrors ``forward_cuda`` silently wrong.

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in __init__ into a fresh buffer."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu once a kernel with the
    # correct operand order (x[:d] * silu(x[d:])) is available.
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


125
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """Gated GELU activation, as used by GeGLU.

    Computes x -> GELU(x[..., :d]) * x[..., d:] where d = x.shape[-1] // 2.
    ``approximate`` selects the GELU variant and must be "none" or "tanh".

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # After validation only "none" and "tanh" remain, so a two-way
        # choice per platform is exhaustive.
        exact = approximate == "none"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (torch.ops._C.gelu_and_mul
                       if exact else torch.ops._C.gelu_tanh_and_mul)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_and_mul
                       if exact else ipex_ops.gelu_tanh_and_mul)

    def _run_fused(self, x: torch.Tensor) -> torch.Tensor:
        """Allocate the half-width output and let ``self.op`` fill it."""
        half = x.shape[-1] // 2
        result = torch.empty(x.shape[:-1] + (half, ),
                             dtype=x.dtype,
                             device=x.device)
        self.op(result, x)
        return result

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        activated = F.gelu(x[..., :half], approximate=self.approximate)
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in __init__."""
        return self._run_fused(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel bound in __init__ (same call shape as CUDA)."""
        return self._run_fused(x)

    def extra_repr(self) -> str:
        """Show the configured GELU variant in the module repr."""
        return f'approximate={repr(self.approximate)}'

175

176
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation registered as "gelu_new"."""

    def __init__(self):
        super().__init__()
        # Bind the platform-specific kernel; other platforms leave
        # ``self.op`` unset.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        inner = scale * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the out-parameter kernel into a buffer shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the XPU op, which is called with ``x`` only and returns the result."""
        return self.op(x)
200

201

202
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation registered as "gelu_fast"."""

    def __init__(self):
        super().__init__()
        # Bind the platform-specific kernel; other platforms leave
        # ``self.op`` unset.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the out-parameter kernel into a buffer shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the XPU op, which is called with ``x`` only and returns the result."""
        return self.op(x)
225

226

227
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        super().__init__()
        # Bind the platform-specific kernel; other platforms leave
        # ``self.op`` unset.
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the out-parameter kernel into a buffer shaped like ``x``."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the XPU kernel, which takes the same (out, x) arguments."""
        result = torch.empty_like(x)
        self.op(result, x)
        return result

255

256
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the native implementation; no fused kernel is used."""
        return self.forward_native(x)


270
271
272
273
274
275
276
277
278
class ScaledActivation(nn.Module):
    """Wrap an activation module and divide its output by learned scales.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Create the ``scales`` parameter for this rank.

        Args:
            act_module: Activation applied before the division.
            intermediate_size: Full (unsharded) length of the scale vector.
            input_is_parallel: When true, this rank only holds
                ``intermediate_size / tp_size`` scales.
            params_dtype: Dtype of ``scales``; the torch default dtype is
                used when omitted.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel

        partition_size = (divide(intermediate_size,
                                 get_tensor_model_parallel_world_size())
                          if input_is_parallel else intermediate_size)
        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)

        # Left uninitialized here; values arrive through ``weight_loader``.
        self.scales = nn.Parameter(torch.empty(partition_size, dtype=dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act(x) / scales``."""
        activated = self.act(x)
        return activated / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``.

        When ``input_is_parallel`` is set, the slice starting at
        ``tp_rank * shard_size`` along dim 0 is taken, where ``shard_size``
        is the length of ``param``.
        """
        target = param.data
        if self.input_is_parallel:
            shard_len = target.shape[0]
            offset = get_tensor_model_parallel_rank() * shard_len
            loaded_weight = loaded_weight.narrow(0, offset, shard_len)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

311

312
313
314
315
316
317
318
319
320
321
322
323
324
# Maps a lower-case activation name to a zero-argument factory for the
# corresponding module.  Lookups go through LazyDict
# (presumably the factory is invoked on access - confirm against LazyDict).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
330
331


332
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation function by case-insensitive name.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn_name]

    raise ValueError(
        f"Activation function {act_fn_name!r} is not supported.")
340
341
342
343
344
345
346
347


# Maps a lower-case activation name to a zero-argument factory for its
# fused "activation and multiply" module.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu":
    lambda: GeluAndMul(),
    "silu":
    lambda: SiluAndMul(),
})


348
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul function (e.g. SiluAndMul) by name.

    The name is matched case-insensitively.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]

    raise ValueError(
        f"Activation function {act_fn_name!r} is not supported.")