"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    FATReLU keeps an element of the first half only when it is strictly
    greater than ``threshold`` and replaces it with 0 otherwise
    (``F.threshold`` semantics); the second half is used unchanged as the
    multiplicative gate.

    Args:
        threshold: Cut-off applied to the first half of ``x``.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        # F.threshold: y = x if x > threshold else 0.0
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path using the ``fatrelu_and_mul`` custom op."""
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # ``out`` is allocated uninitialised and returned as-is, so the op is
        # expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.fatrelu_and_mul(out, x, self.threshold)
        return out

    def extra_repr(self) -> str:
        return f'threshold={self.threshold!r}'


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path using the ``silu_and_mul`` custom op."""
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # ``out`` is allocated uninitialised and returned as-is, so the op is
        # expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path using the IPEX ``silu_and_mul`` op."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out

@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Args:
        approximate: ``"none"`` for exact GELU or ``"tanh"`` for the tanh
            approximation; passed to ``F.gelu`` on the native path and used
            to pick the fused kernel on the CUDA/XPU paths.

    Raises:
        ValueError: If ``approximate`` is not ``"none"`` or ``"tanh"``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        # Validate before storing so an instance never holds a bad mode.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        self.approximate = approximate

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path using the ``gelu_and_mul`` / ``gelu_tanh_and_mul`` ops."""
        from vllm import _custom_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        # ``out`` is allocated uninitialised and returned as-is, so the op is
        # expected to fill it in place.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # __init__ guarantees approximate is exactly one of these two modes.
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path using the IPEX ``gelu_and_mul`` / ``gelu_tanh_and_mul``."""
        from vllm._ipex_ops import ipex_ops as ops

        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "none":
            ops.gelu_and_mul(out, x)
        elif self.approximate == "tanh":
            ops.gelu_tanh_and_mul(out, x)
        return out

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'

@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU, registered as ``"gelu_new"``.

    Computes, element-wise,
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    The output has the same shape and dtype as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path using the ``gelu_new`` custom op."""
        from vllm import _custom_ops as ops

        # The op writes into a caller-provided buffer of the input's shape.
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path using the IPEX ``gelu_new`` op."""
        from vllm._ipex_ops import ipex_ops as ops

        # Unlike the CUDA op, the IPEX op returns its result directly.
        return ops.gelu_new(x)

@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-approximated GELU, registered as ``"gelu_fast"``.

    Computes, element-wise,
        0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))
    where 0.7978845608 is sqrt(2 / pi) precomputed and the cubic term is
    factored as x * (1 + 0.044715 * x**2). The output has the same shape
    and dtype as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path using the ``gelu_fast`` custom op."""
        from vllm import _custom_ops as ops

        # The op writes into a caller-provided buffer of the input's shape.
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path using the IPEX ``gelu_fast`` op."""
        from vllm._ipex_ops import ipex_ops as ops

        # Unlike the CUDA op, the IPEX op returns its result directly.
        return ops.gelu_fast(x)

@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).

    Reference implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    The output has the same shape and dtype as the input.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path using the ``gelu_quick`` custom op."""
        from vllm import _custom_ops as ops

        # The op writes into a caller-provided buffer of the input's shape.
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """XPU path using the IPEX ``gelu_quick`` op."""
        from vllm._ipex_ops import ipex_ops as ops

        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out

@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2

    Computes relu(x) ** 2 element-wise; the output has the input's shape.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel: the CUDA path reuses the native implementation.
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ. The output is
    ``act_module(x) / scales``, where ``scales`` is a 1-D parameter that
    broadcasts over the last dimension of the activation output and is
    filled by :meth:`weight_loader`.

    Args:
        act_module: Activation applied before the division by ``scales``.
        intermediate_size: Full (unsharded) size of the scaled dimension.
        input_is_parallel: If True, the input is split across
            tensor-parallel ranks, so only
            ``intermediate_size // tp_size`` scales are kept on this rank.
        params_dtype: dtype of ``scales``; defaults to
            ``torch.get_default_dtype()``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Left uninitialised on purpose: the values come from weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's shard of it) into ``param``.

        When ``input_is_parallel`` is set, ``loaded_weight`` is taken to be
        the full tensor, and this rank copies rows
        ``[tp_rank * n, (tp_rank + 1) * n)`` along dim 0, where
        ``n = param.shape[0]``.

        Raises:
            AssertionError: If the (sharded) weight shape differs from
                ``param``'s shape.
        """
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape, (
            f"ScaledActivation weight shape mismatch: parameter has "
            f"{tuple(param_data.shape)}, loaded weight has "
            f"{tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)

# Zero-argument factories for single-input activation modules, keyed by the
# lower-case names that get_act_fn looks up. Whether a factory is invoked
# once and cached or on every lookup is decided by LazyDict (vllm.utils),
# which is not visible here.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: Activation name; it is lower-cased before lookup, so
            e.g. ``"GELU"`` and ``"gelu"`` resolve to the same entry.

    Returns:
        The module that ``_ACTIVATION_REGISTRY`` yields for the name.

    Raises:
        ValueError: If the lower-cased name is not registered.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]


# Zero-argument factories for the "activation(x[..., :d]) * x[..., d:]"
# modules, keyed by the lower-case names that get_act_and_mul_fn looks up.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    dict(
        gelu=lambda: GeluAndMul(),
        silu=lambda: SiluAndMul(),
    ))


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (e.g. SiluAndMul) function by name.

    Args:
        act_fn_name: Name of the activation applied to the first half of
            the input; it is lower-cased before lookup.

    Returns:
        The module that ``_ACTIVATION_AND_MUL_REGISTRY`` yields for the name.

    Raises:
        ValueError: If the lower-cased name is not registered.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]