activation.py 2.69 KB
Newer Older
1
"""Custom activation functions."""
2
3
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
4
5
6
import torch
import torch.nn as nn

Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm import activation_ops
8
from vllm.model_executor.layers.quantization import QuantizationConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
11


class SiluAndMul(nn.Module):
    """Gated activation used by SwiGLU MLPs.

    The last dimension of ``x`` is treated as two halves of size
    ``d = x.shape[-1] // 2``; the result is ``silu(x[..., :d]) * x[..., d:]``,
    computed by the ``activation_ops.silu_and_mul`` kernel.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): x.shape[-1] is assumed to be even; for an odd size
        # `half` rounds down — confirm how the kernel treats that case.
        half = x.shape[-1] // 2
        # The kernel takes its destination tensor as the first argument, so
        # allocate one that keeps every leading dim of x and halves the last.
        result = torch.empty((*x.shape[:-1], half),
                             dtype=x.dtype,
                             device=x.device)
        activation_ops.silu_and_mul(result, x)
        return result
27
28
29
30
31


class NewGELU(nn.Module):
    """The ``gelu_new`` activation, backed by ``activation_ops.gelu_new``.

    The output has the same shape, dtype and device as the input.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The kernel writes into a destination tensor passed as its first
        # argument; allocate one matching x exactly.
        result = torch.empty_like(x)
        activation_ops.gelu_new(result, x)
        return result


class FastGELU(nn.Module):
    """The ``gelu_fast`` activation, backed by ``activation_ops.gelu_fast``.

    The output has the same shape, dtype and device as the input.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The kernel writes into a destination tensor passed as its first
        # argument; allocate one matching x exactly.
        result = torch.empty_like(x)
        activation_ops.gelu_fast(result, x)
        return result


45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    Applies ``act_module`` and then divides the result element-wise by the
    learnable 1-D ``scales`` parameter of length ``hidden_size``. By torch
    broadcasting rules, ``scales`` lines up with the last dimension of the
    activation output.

    This is used for some quantization methods like AWQ.

    Args:
        act_module: Activation module applied before scaling.
        hidden_size: Number of entries in ``scales``.
        params_dtype: dtype of ``scales``.
        device: Device on which ``scales`` is allocated. Defaults to
            ``"cuda"``, matching the previously hard-coded device. Callers can
            pass e.g. ``"cpu"`` to build the module without a GPU.

    Note:
        ``scales`` is allocated with ``torch.empty`` and is therefore
        uninitialized. NOTE(review): presumably it is filled from checkpoint
        weights before the first forward pass — confirm with the loader.
    """

    def __init__(
        self,
        act_module: nn.Module,
        hidden_size: int,
        params_dtype: torch.dtype,
        device: "str | torch.device" = "cuda",
    ):
        super().__init__()
        self.act = act_module
        self.scales = nn.Parameter(
            torch.empty(hidden_size, dtype=params_dtype, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide after the activation; scales broadcasts over the last dim.
        return self.act(x) / self.scales


66
67
68
69
70
71
72
73
74
# Maps a lowercase activation name to a ready-made module instance. Entries are
# created once at import time; get_act_fn returns these same objects (or wraps
# them), so every caller asking for a given name shares one instance.
_ACTIVATION_REGISTRY = dict(
    gelu=nn.GELU(),
    gelu_fast=FastGELU(),
    gelu_new=NewGELU(),
    gelu_pytorch_tanh=nn.GELU(approximate="tanh"),
    relu=nn.ReLU(),
)


75
76
77
78
79
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: Activation name, matched case-insensitively against the
            keys of ``_ACTIVATION_REGISTRY``.
        quant_config: Optional quantization config. If it lists this
            activation in ``get_scaled_act_names()``, the activation is
            wrapped in a ``ScaledActivation``.
        intermediate_size: Length of the scale vector. It is required only
            when the activation is wrapped.

    Returns:
        The shared registry instance for the name, or a new
        ``ScaledActivation`` wrapping it.

    Raises:
        ValueError: If the name is not registered, or if scaling is required
            but ``intermediate_size`` was not given.
    """
    name = act_fn_name.lower()
    if name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {name!r} is not supported.")
    base_act = _ACTIVATION_REGISTRY[name]

    needs_scaling = (quant_config is not None
                     and name in quant_config.get_scaled_act_names())
    if not needs_scaling:
        return base_act

    if intermediate_size is None:
        raise ValueError("intermediate_size must be specified for scaled "
                         "activation functions.")
    # The scale parameters take whatever dtype torch is currently defaulting to.
    return ScaledActivation(
        base_act,
        intermediate_size,
        params_dtype=torch.get_default_dtype(),
    )