"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out
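
# Example (sketch): SiluAndMul halves the last dimension by gating one half
# with the SiLU of the other. The PyTorch-native `_forward` path runs on any
# device; `forward` assumes the compiled vllm._C CUDA kernels.
#   >>> act = SiluAndMul()
#   >>> x = torch.randn(4, 16, 2 * 128)  # (batch_size, seq_len, 2 * d)
#   >>> act._forward(x).shape
#   torch.Size([4, 16, 128])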


class GeluAndMul(nn.Module):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_and_mul(out, x)
        return out
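
# Example (sketch): GeluAndMul applies the same split-and-gate pattern with
# GELU as the gate. A hypothetical consistency check against the native path
# (assumes a CUDA device and a CUDA build of the vllm._C ops):
#   >>> act = GeluAndMul()
#   >>> x = torch.randn(8, 2 * 64, device="cuda")
#   >>> out_kernel = act(x)           # custom CUDA kernel
#   >>> out_native = act._forward(x)  # should match to within kernel tolerance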


class NewGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out
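
# Note: NewGELU is the GPT-2-style tanh approximation of GELU ("gelu_new"),
#   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).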


class FastGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out
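
# Note: FastGELU ("gelu_fast") evaluates the same tanh approximation as
# NewGELU with sqrt(2 / pi) pre-folded as 0.7978845608, since
#   x * 0.7978845608 * (1 + 0.044715 * x * x)
# equals sqrt(2 / pi) * (x + 0.044715 * x**3) up to rounding of the constant.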


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
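
# Example (sketch): under tensor parallelism each rank keeps only its shard of
# the per-channel scales, and weight_loader narrows the full scale vector to
# that shard. A hypothetical CPU-only usage with sharding disabled:
#   >>> act = ScaledActivation(nn.GELU(), intermediate_size=11008,
#   ...                        input_is_parallel=False)
#   >>> act.weight_loader(act.scales, torch.ones(11008))
#   >>> x = torch.randn(2, 11008)
#   >>> y = act(x)  # GELU(x) divided element-wise by act.scales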


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn
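

# Example (sketch): resolving an activation by name. With no quant_config the
# registry entry is returned directly; quantization methods such as AWQ can
# report scaled activations, in which case intermediate_size must be given so
# the result is wrapped in ScaledActivation.
#   >>> act_fn = get_act_fn("gelu_new")
#   >>> isinstance(act_fn, NewGELU)
#   True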