activation.py 14.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4
import math
5
6
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.custom_op import CustomOp
14
from vllm.model_executor.utils import set_weight_attrs
15
from vllm.platforms import current_platform
16
from vllm.utils import LazyDict
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18


19
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation.

    The last dimension of the input is split into two halves of width
    d = x.shape[-1] // 2 and the result is FATReLU(x[..., :d]) * x[..., d:].
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        """Store the FATReLU threshold and pick a platform implementation.

        Args:
            threshold: Values of the gate half that are not strictly
                greater than this are replaced by zero.
        """
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # Presumably `_forward_method` is the dispatch hook read by
            # CustomOp; on CPU it is pointed at the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation: threshold the gate, then multiply."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the custom kernel, which receives a preallocated output."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x, self.threshold)
        return out
53
54


55
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """Gated SiLU activation (SwiGLU).

    Computes silu(x[..., :d]) * x[..., d:] with d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the kernel that matches the current platform, if any."""
        super().__init__()
        uses_c_kernel = (current_platform.is_cuda_alike()
                         or current_platform.is_cpu())
        if uses_c_kernel:
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.silu(gate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Call the bound kernel with a freshly allocated output tensor."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same calling convention as forward_cuda, using the XPU op."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        """Compute on a 2-D view and restore the leading dimensions.

        SiLU is spelled out as gate * sigmoid(gate) rather than F.silu.
        """
        half = x.shape[-1] // 2
        flat = x.view(-1, x.shape[-1])
        gate = flat[:, :half]
        activated = gate * F.sigmoid(gate)
        gated = activated * flat[:, half:]
        return gated.view(*x.shape[:-1], half)

100

101
102
103
104
105
106
107
108
109
110
111
112
113
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """Gated SiLU activation with the halves swapped relative to SiluAndMul.

    Computes x[..., :d] * silu(x[..., d:]) with d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        """Bind the kernel that matches the current platform, if any."""
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            # NOTE(review): this binds silu_and_mul, whose operand order is
            # the opposite of this class. No forward_xpu is defined in this
            # class, so the op looks unused here - confirm before relying
            # on it.
            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            # Presumably `_forward_method` is the dispatch hook read by
            # CustomOp; on CPU it is pointed at the PyTorch implementation.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        up, gate = x[..., :half], x[..., half:]
        return up * F.silu(gate)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Call the bound kernel with a freshly allocated output tensor."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


Robert Shaw's avatar
Robert Shaw committed
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """Sparsified, gated GELU activation.

    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)

    Within this module the gate is x[..., :d] and the up projection is
    x[..., d:], with d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        """Validate the settings and precompute the cutoff multiplier.

        Args:
            activation_sparsity: Target sparsity; converted to a number of
                standard deviations via the standard normal inverse CDF.
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If `approximate` is not a known mode, or if
                `activation_sparsity` is exactly 0.0.
        """
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        standard_normal = torch.distributions.normal.Normal(0, 1)
        sparsity = torch.tensor(activation_sparsity, dtype=torch.float32)
        self.std_multiplier = standard_normal.icdf(sparsity)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Get % sparse percentile of the Gaussian distribution."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True, unbiased=False)
        cutoff = mu + sigma * self.std_multiplier
        # Everything at or below the per-row cutoff becomes zero.
        return F.relu(x - cutoff)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        sparse_gate = self._gaussian_topk(x[..., :half])
        activated = F.gelu(sparse_gate, approximate=self.approximate)
        return activated * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No dedicated kernel is used; delegates to forward_native."""
        return self.forward_native(x)


189
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """Gated GELU activation (GeGLU).

    Computes GELU(x[..., :d]) * x[..., d:] with d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        """Validate the GELU mode and bind the matching platform kernel.

        Args:
            approximate: GELU approximation mode, "none" or "tanh".

        Raises:
            ValueError: If `approximate` is not a known mode.
        """
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        # After validation the mode is exactly one of "none" / "tanh".
        exact = approximate == "none"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = (torch.ops._C.gelu_and_mul
                       if exact else torch.ops._C.gelu_tanh_and_mul)
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_and_mul
                       if exact else ipex_ops.gelu_tanh_and_mul)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Call the bound kernel with a freshly allocated output tensor."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same calling convention as forward_cuda, using the XPU op."""
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        """Return the approximation mode for inclusion in the module repr."""
        return f'approximate={repr(self.approximate)}'

239

240
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-based GELU approximation registered as "gelu_new"."""

    def __init__(self):
        """Bind the kernel that matches the current platform, if any."""
        super().__init__()
        uses_c_kernel = (current_platform.is_cuda_alike()
                         or current_platform.is_cpu())
        if uses_c_kernel:
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        scale = math.sqrt(2.0 / math.pi)
        cubic = x + 0.044715 * torch.pow(x, 3.0)
        return 0.5 * x * (1.0 + torch.tanh(scale * cubic))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Call the bound kernel with an output shaped like the input."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return whatever the XPU op returns; no output buffer is passed."""
        return self.op(x)
264

265

266
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Tanh-based GELU approximation registered as "gelu_fast"."""

    def __init__(self):
        """Bind the kernel that matches the current platform, if any."""
        super().__init__()
        uses_c_kernel = (current_platform.is_cuda_alike()
                         or current_platform.is_cpu())
        if uses_c_kernel:
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        poly = 1.0 + 0.044715 * x * x
        inner = x * 0.7978845608 * poly
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Call the bound kernel with an output shaped like the input."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Return whatever the XPU op returns; no output buffer is passed."""
        return self.op(x)
289

290

291
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        """Bind the kernel that matches the current platform, if any."""
        super().__init__()
        uses_c_kernel = (current_platform.is_cuda_alike()
                         or current_platform.is_cpu())
        if uses_c_kernel:
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Call the bound kernel with an output shaped like the input."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        """Same calling convention as forward_cuda, using the XPU op."""
        out = torch.empty_like(x)
        self.op(out, x)
        return out

319

320
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return torch.square(rectified)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """No dedicated kernel is used; delegates to forward_native."""
        return self.forward_native(x)


334
335
336
337
338
339
340
341
342
class ScaledActivation(nn.Module):
    """Wrap an activation module and divide its output by learned scales.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Create the (uninitialized) per-channel scale parameter.

        Args:
            act_module: Activation applied before the division.
            intermediate_size: Full, unsharded length of the scale vector.
            input_is_parallel: When true, this rank only holds
                intermediate_size / tensor-parallel-world-size scales.
            params_dtype: Dtype of the scales; the torch default dtype is
                used when None.
        """
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel

        partition_size = intermediate_size
        if input_is_parallel:
            world_size = get_tensor_model_parallel_world_size()
            partition_size = divide(intermediate_size, world_size)

        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)
        self.scales = nn.Parameter(torch.empty(partition_size, dtype=dtype))
        # Attach the loader so checkpoint loading can find it on the param.
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by the scales."""
        activated = self.act(x)
        return activated / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy checkpoint scales into `param`, sharding when parallel.

        With `input_is_parallel` set, only the slice of `loaded_weight`
        along dim 0 that belongs to this tensor-parallel rank is copied.
        """
        target = param.data
        if self.input_is_parallel:
            rank = get_tensor_model_parallel_rank()
            rows = target.shape[0]
            loaded_weight = loaded_weight.narrow(0, rank * rows, rows)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

375

376
377
378
379
380
381
382
383
384
385
386
387
388
# Maps a lower-case activation name to a zero-argument factory for the
# corresponding module. Presumably LazyDict calls the factory on lookup
# rather than at import time - confirm against vllm.utils.LazyDict.
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
394
395


396
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation function by name, ignoring case.

    Args:
        act_fn_name: Name of the activation; lower-cased before lookup.

    Returns:
        The entry stored in the activation registry under that name.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn_name]

    raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
404
405
406
407
408


# Maps a lower-case activation name to a zero-argument factory for the
# fused "activation-and-multiply" module. "gelu" and "geglu" both build a
# GeluAndMul with its default (non-approximate) mode.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
})


413
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Look up an activation-and-mul (i.e. SiluAndMul) function by name.

    Args:
        act_fn_name: Name of the activation; lower-cased before lookup.

    Returns:
        The entry stored in the activation-and-mul registry under that name.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]

    raise ValueError(f"Activation function {act_fn_name!r} is not supported.")