activation.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4
import math
5
6
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
import torch
import torch.nn as nn
9
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.custom_op import CustomOp
14
from vllm.model_executor.utils import set_weight_attrs
15
from vllm.platforms import current_platform
16
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
17
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19


20
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """FATReLU-and-multiply activation.

    Splits the last dimension in half and returns
    ``FATReLU(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.
    FATReLU keeps values strictly greater than ``threshold`` and replaces
    everything else with 0. This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel consumed by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is bound; route calls to the PyTorch version.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        # F.threshold keeps x where x > threshold, else writes 0.0.
        gate = F.threshold(x[..., :half], self.threshold, 0.0)
        up = x[..., half:]
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        # The kernel fills `result`; its own return value is not used.
        self.op(result, x, self.threshold)
        return result
54
55


56
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel, selected in forward_cuda when
            # envs.VLLM_USE_OPT_OP is set.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        When called eagerly (not while torch.compile is tracing) and
        __init__ bound the CUDA/CPU kernels, this delegates to forward_cuda
        so the fused kernel is used. Otherwise it evaluates the plain
        PyTorch expression.

        The ``op_opt`` check matters on XPU, which binds only ``self.op``,
        and on platforms where __init__ binds no kernel at all. There,
        delegating would raise AttributeError inside forward_cuda.
        """
        if not torch.compiler.is_compiling() and hasattr(self, "op_opt"):
            # Eager call: prefer the fused kernel over unfused PyTorch ops.
            # Note: is_compiling() reports torch.compile tracing; it does
            # not detect CUDA graph capture.
            return self.forward_cuda(x)
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # Both kernels write their result into `out`.
        if envs.VLLM_USE_OPT_OP:
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten the leading dims to 2-D, compute silu as x * sigmoid(x)
        # on the gate half, then restore the leading shape.
        d = x.shape[-1] // 2
        x_reshaped = x.view(-1, x.shape[-1])
        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
        result = s * x_reshaped[:, d:]
        return result.view(*x.shape[:-1], d)

108

109
110
111
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
    The gating order is the reverse of SiluAndMul.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        # XPU deliberately binds no kernel. ipex_ops.silu_and_mul (the op
        # SiluAndMul uses on XPU) computes silu(x[:d]) * x[d:], which is the
        # opposite gating order. This class also has no forward_xpu that
        # would call it. XPU falls back to the base-class XPU dispatch
        # (presumably forward_native; confirm in CustomOp.forward_xpu).
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # The kernel writes its result into `out`.
        self.op(out, x)
        return out

    # TODO implement forward_xpu for MulAndSilu
    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.

    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)

    Here the input already holds both projections concatenated along the
    last dim. The first half is the gate, which is sparsified and then
    passed through GELU. The second half is the multiplier.

    Args:
        activation_sparsity: target fraction of gate values to zero out.
            Must lie strictly between 0 and 1.
        approximate: GELU approximation, "none" or "tanh".

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        # Normal.icdf does not validate its argument. Values outside (0, 1),
        # or NaN, give a NaN multiplier, and 1.0 gives +inf. Either would
        # silently turn the outputs into NaN or all zeros, so reject them.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                "activation_sparsity must be in the open interval (0, 1), "
                f"got {activation_sparsity}.")
        target_sparsity_tensor = torch.tensor(activation_sparsity,
                                              dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        # z-score below which `activation_sparsity` of a standard normal's
        # mass lies.
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Sparsify ``x`` along its last dim with a Gaussian-quantile cutoff.

        For each row, fits a mean and a population std (unbiased=False) and
        sets ``cutoff = mean + std * std_multiplier``. Returns
        ``relu(x - cutoff)``: values at or below the cutoff become 0, and
        values above it are shifted down by the cutoff.
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return F.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel; the native ops are used on CUDA too.
        return self.forward_native(x)


197
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation: GELU-gate the first half, multiply by the second.

    Computes ``GELU(x[..., :d]) * x[..., d:]`` with ``d = x.shape[-1] // 2``.
    It uses the exact ("none") or the tanh-approximated ("tanh") GELU.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Past the check above, approximate is either "none" or "tanh".
        use_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            # `op_opt` is the variant forward_cuda picks when
            # envs.VLLM_USE_OPT_OP is set.
            if use_tanh:
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
            else:
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_tanh_and_mul
                       if use_tanh else ipex_ops.gelu_and_mul)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.gelu(gate, approximate=self.approximate) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        # The kernel fills `out` in place.
        kernel(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        out = torch.empty((*x.shape[:-1], half),
                          dtype=x.dtype,
                          device=x.device)
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

252

253
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU, registered as "gelu_new".

    Native formula:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        inner = sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        # Out-parameter kernel: writes its output into `result`.
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The XPU op returns its output rather than filling a buffer.
        return self.op(x)
277

278

279
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-approximated GELU, registered as "gelu_fast".

    The constant 0.7978845608 is sqrt(2 / pi), so the native formula is
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * x * (1 + 0.044715 * x * x)))``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        # Out-parameter kernel: writes its output into `result`.
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The XPU op returns its output rather than filling a buffer.
        return self.op(x)
302

303

304
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """Sigmoid-approximated GELU: ``x * sigmoid(1.702 * x)``."""
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # Out-parameter kernel: writes its output into `out`.
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): confirm ipex_ops.gelu_quick takes (out, x) in this
        # order, as the other out-parameter ops here do.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

332

333
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Squared ReLU, ``relu(x) ** 2``, as introduced in
    https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return F.relu(x).square()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel; the native ops are used on CUDA as well.
        return self.forward_native(x)


347
348
349
350
351
352
353
354
355
class ScaledActivation(nn.Module):
    """Wraps an activation and divides its output by per-channel scales.

    This is used for some quantization methods like AWQ. The forward pass
    returns ``act(x) / scales``. ``scales`` is a parameter that the weight
    loader fills. When ``input_is_parallel`` is set, it holds only this
    tensor-parallel rank's slice of ``intermediate_size``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        # With sharded input, each TP rank stores an equal share of the
        # channels; otherwise the full width is kept.
        local_size = (divide(intermediate_size,
                             get_tensor_model_parallel_world_size())
                      if input_is_parallel else intermediate_size)
        dtype = (params_dtype
                 if params_dtype is not None else torch.get_default_dtype())
        self.scales = nn.Parameter(torch.empty(local_size, dtype=dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.act(x)
        return activated / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (or this rank's slice of it) into ``param``."""
        target = param.data
        if self.input_is_parallel:
            # Take the contiguous slice along dim 0 that belongs to this rank.
            width = target.shape[0]
            offset = get_tensor_model_parallel_rank() * width
            loaded_weight = loaded_weight.narrow(0, offset, width)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

388

389
390
391
392
393
394
395
396
397
398
399
400
401
# Single-input activations looked up by get_act_fn, keyed by lower-case
# name. Each value is a zero-argument factory; LazyDict (vllm.utils) is
# expected to invoke it on first access (confirm caching semantics there).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
407
408


409
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Return the single-input activation registered under ``act_fn_name``.

    The lookup is case-insensitive.

    Raises:
        ValueError: if the lower-cased name is not in _ACTIVATION_REGISTRY.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")
417
418
419
420
421


# Gated "activation-and-mul" layers looked up by get_act_and_mul_fn.
# "gelu" and "geglu" are aliases that both build a GeluAndMul.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),  # GELU(x[:d]) * x[d:]
    "silu": lambda: SiluAndMul(),  # silu(x[:d]) * x[d:]
    "geglu": lambda: GeluAndMul(),  # alias of "gelu"
})


426
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Return the activation-and-mul layer (e.g. SiluAndMul) for a name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if the lower-cased name is not in
            _ACTIVATION_AND_MUL_REGISTRY.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")