"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    FATReLU zeroes every element of the gate half x[:d] that is less than or
    equal to ``threshold`` and passes the remaining elements through as-is.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        # Gate values <= threshold are replaced by 0 (see F.threshold).
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # F.threshold(x, t, v) keeps x where x > t and writes v elsewhere.
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused ``fatrelu_and_mul`` custom kernel."""
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        return out

    def extra_repr(self) -> str:
        # Mirrors GeluAndMul.extra_repr so the threshold shows up in print().
        return f'threshold={self.threshold!r}'


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused CUDA kernel into a freshly allocated output.

        When ``envs.VLLM_USE_OPT_OP`` is truthy the ``silu_and_mul_opt``
        kernel is called instead of ``silu_and_mul``.
        """
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if envs.VLLM_USE_OPT_OP:
            ops.silu_and_mul_opt(out, x)
        else:
            ops.silu_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX ``silu_and_mul`` op into a freshly allocated output."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Args:
        approximate: GELU variant, ``"none"`` or ``"tanh"``; passed straight
            through to ``F.gelu`` in the native path and used to pick the
            kernel in the CUDA/XPU paths.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        # Validate before storing so an invalid mode never lands on the module.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused CUDA kernel for the configured GELU variant.

        When ``envs.VLLM_USE_OPT_OP`` is truthy the ``*_opt`` kernel variant
        is called.

        Raises:
            ValueError: If ``self.approximate`` is not ``"none"``/``"tanh"``
                (only possible if it was changed after construction).
        """
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            if envs.VLLM_USE_OPT_OP:
                ops.gelu_and_mul_opt(out, x)
            else:
                ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            if envs.VLLM_USE_OPT_OP:
                ops.gelu_tanh_and_mul_opt(out, x)
            else:
                ops.gelu_tanh_and_mul(out, x)
        else:
            # Without this, ``out`` would be returned uninitialized.
            raise ValueError(f"Unknown approximate mode: {self.approximate}")
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX kernel for the configured GELU variant.

        Raises:
            ValueError: If ``self.approximate`` is not ``"none"``/``"tanh"``.
        """
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        else:
            # Without this, ``out`` would be returned uninitialized.
            raise ValueError(f"Unknown approximate mode: {self.approximate}")
        return out

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh approximation of GELU.

    Computes, element-wise,
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_new`` custom kernel into a tensor shaped like x."""
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX ``gelu_new`` op."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_new(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh approximation of GELU with a hard-coded constant.

    Computes, element-wise,
    0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x))),
    where 0.7978845608 ~= sqrt(2 / pi).
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_fast`` custom kernel into a tensor shaped like x."""
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return the result of the IPEX ``gelu_fast`` op."""
        from vllm._ipex_ops import ipex_ops as ops

        return ops.gelu_fast(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Quick GELU approximation: x * sigmoid(1.702 * x), element-wise."""

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ``gelu_quick`` custom kernel into a tensor shaped like x."""
        from vllm import _custom_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Run the IPEX ``gelu_quick`` op into a tensor shaped like x."""
        from vllm._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out


@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, relu(x) ** 2, applied element-wise.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel is used; CUDA reuses the native implementation.
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ. ``forward`` returns
    ``act_module(x)`` divided element-wise by the ``scales`` parameter, whose
    length is the intermediate size of this tensor-parallel partition.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            # Each tensor-parallel rank holds only its slice of the scales.
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; the values are filled in by weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy checkpoint scales into ``param`` in place.

        When ``input_is_parallel`` is set, only this rank's contiguous shard
        of ``loaded_weight`` along dim 0 (sized like ``param``) is copied.

        Raises:
            ValueError: If the (sharded) checkpoint tensor's shape differs
                from ``param``'s shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O`` and ``copy_`` would then silently broadcast a
        # mis-shaped tensor into the parameter.
        if param_data.shape != loaded_weight.shape:
            raise ValueError(
                f"ScaledActivation scales shape mismatch: parameter has "
                f"{tuple(param_data.shape)}, checkpoint has "
                f"{tuple(loaded_weight.shape)}.")
        param_data.copy_(loaded_weight)


# Lower-case activation name -> zero-argument factory for that module.
# The factories are handed to LazyDict (vllm.utils); presumably a module is
# only constructed when its name is first looked up -- TODO confirm there.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: Activation name; it is lower-cased before lookup, so
            matching is case-insensitive.

    Returns:
        The entry of ``_ACTIVATION_REGISTRY`` for the lower-cased name.

    Raises:
        ValueError: If the name is not in ``_ACTIVATION_REGISTRY``.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]


# Gated "activation(x[..., :d]) * x[..., d:]" modules, keyed by the name of
# the gating activation. Each value is a zero-argument factory.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
    }
)


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (e.g. SiluAndMul) function by name.

    Args:
        act_fn_name: Name of the gating activation (``"gelu"`` or ``"silu"``
            in ``_ACTIVATION_AND_MUL_REGISTRY``); lower-cased before lookup.

    Returns:
        The entry of ``_ACTIVATION_AND_MUL_REGISTRY`` for the lower-cased name.

    Raises:
        ValueError: If the name is not in ``_ACTIVATION_AND_MUL_REGISTRY``.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]