activation.py 10.8 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
10
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
11
from vllm.model_executor.custom_op import CustomOp
12
from vllm.model_executor.layers.quantization import QuantizationConfig
13
from vllm.model_executor.utils import set_weight_attrs
14
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16


17
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2 and FATReLU(v) = v if v > threshold else 0.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Args:
        threshold: Entries of the first (gate) half that are not strictly
            greater than this value are replaced by 0 before multiplying.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # F.threshold keeps x1 where x1 > threshold and writes 0.0 elsewhere.
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        return out

    def extra_repr(self) -> str:
        # Surface the configured threshold in repr(), mirroring GeluAndMul.
        return f'threshold={self.threshold!r}'
49
50


51
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """SwiGLU-style gated activation.

    The last dimension of ``x`` is split into two halves of width
    ``half = x.shape[-1] // 2``: a gate ``x[..., :half]`` and a value
    ``x[..., half:]``. The result is ``silu(gate) * value``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, value = x[..., :half], x[..., half:]
        return F.silu(gate) * value

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        # The fused kernel writes silu(gate) * value into `result`.
        ops.silu_and_mul(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        ops.silu_and_mul(result, x)
        return result

85

86
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU-style gated activation.

    The last dimension of ``x`` is split at ``half = x.shape[-1] // 2`` and
    the result is ``GELU(x[..., :half]) * x[..., half:]``.

    Args:
        approximate: ``"none"`` for exact GELU or ``"tanh"`` for the tanh
            approximation; any other value raises ``ValueError``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, value = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * value

    def _run_fused_kernel(self, ops, x: torch.Tensor) -> torch.Tensor:
        """Allocate the output and run the fused kernel provided by `ops`."""
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(result, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(result, x)
        return result

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        return self._run_fused_kernel(ops, x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        return self._run_fused_kernel(ops, x)

    def extra_repr(self) -> str:
        return f'approximate={self.approximate!r}'

135

136
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        cubic_term = 0.044715 * torch.pow(x, 3.0)
        inner = scale * (x + cubic_term)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        result = torch.empty_like(x)
        ops.gelu_new(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        # The IPEX op returns a new tensor rather than filling one in.
        return ops.gelu_new(x)
156

157

158
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-approximated GELU with sqrt(2 / pi) hard-coded as 0.7978845608.

    Computes 0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x^2))).
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        inner = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        result = torch.empty_like(x)
        ops.gelu_fast(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        # The IPEX op returns a new tensor rather than filling one in.
        return ops.gelu_fast(x)
177

178

179
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated ("quick") GELU: x -> x * sigmoid(1.702 * x).

    Reference:
    https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

203

204
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU: x -> relu(x) ** 2.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel is used here; CUDA runs the PyTorch-native path.
        return self.forward_native(x)


218
219
220
221
222
223
224
225
226
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    The wrapped activation's output is divided element-wise by the
    ``scales`` parameter, whose values are filled in through
    :meth:`weight_loader`. This is used for some quantization methods like
    AWQ.

    Args:
        act_module: The activation module to wrap.
        intermediate_size: Unsharded length of the ``scales`` vector.
        input_is_parallel: If True, ``scales`` is split along dim 0 across
            tensor-parallel ranks and this rank holds only its slice.
        params_dtype: dtype for ``scales``; defaults to
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by ``scales``."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``.

        When ``input_is_parallel`` is True, the slice starting at
        ``tp_rank * param.shape[0]`` with length ``param.shape[0]`` is taken
        along dim 0 before copying.

        Raises:
            ValueError: If the (sliced) weight's shape differs from
                ``param``'s shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Explicit check instead of `assert`, which `python -O` strips out.
        if param_data.shape != loaded_weight.shape:
            raise ValueError(
                "Scale shape mismatch: parameter has shape "
                f"{tuple(param_data.shape)}, but loaded weight has shape "
                f"{tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)

259

260
261
262
263
264
265
266
267
268
269
270
271
272
# Plain (single-input) activations keyed by lowercase name. Each value is a
# zero-argument factory; LazyDict presumably calls it on first lookup.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
278
279


280
281
282
283
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    The lookup is case-insensitive. If ``quant_config`` lists the name among
    its scaled activations, the module is wrapped in ``ScaledActivation``.

    Raises:
        ValueError: If the name is not registered, or if scaling is required
            but ``intermediate_size`` is None.
    """
    name = act_fn_name.lower()
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")
    act_fn = _ACTIVATION_REGISTRY[name]

    needs_scaling = (quant_config is not None
                     and name in quant_config.get_scaled_act_names())
    if not needs_scaling:
        return act_fn

    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                            params_dtype)
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331


# Gated "activation-and-mul" modules keyed by lowercase name; each value is
# a zero-argument factory.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(dict(
    gelu=lambda: GeluAndMul(),
    silu=lambda: SiluAndMul(),
))


def get_act_and_mul_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation-and-mul module (e.g. SiluAndMul) by name.

    The lookup is case-insensitive. If ``quant_config`` lists the name among
    its scaled activations, the module is wrapped in ``ScaledActivation``.

    Raises:
        ValueError: If the name is not registered, or if scaling is required
            but ``intermediate_size`` is None.
    """
    name = act_fn_name.lower()
    if name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(f"Activation function {name!r} is not supported.")
    act_fn = _ACTIVATION_AND_MUL_REGISTRY[name]

    needs_scaling = (quant_config is not None
                     and name in quant_config.get_scaled_act_names())
    if not needs_scaling:
        return act_fn

    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                            params_dtype)