activation.py 4.8 KB
Newer Older
1
"""Custom activation functions."""
2
import math
3
4
from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
from vllm._C import ops
10
from vllm.model_executor.layers.quantization import QuantizationConfig
11
12
13
14
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16
17


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SiLU-and-multiply via the custom ``ops.silu_and_mul`` kernel.

        The output buffer is allocated here with the same dtype and device as
        ``x`` and with the last dimension halved. The kernel then fills it in
        place.
        """
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out


class NewGELU(nn.Module):
    """Tanh-approximated GELU, registered as ``"gelu_new"``.

    Computes element-wise
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation using the custom ``ops.gelu_new`` kernel.

        The output is a fresh tensor shaped like ``x`` (``torch.empty_like``),
        which the kernel fills in place.
        """
        out = torch.empty_like(x)
        ops.gelu_new(out, x)
        return out


class FastGELU(nn.Module):
    """Tanh-approximated GELU, registered as ``"gelu_fast"``.

    Computes element-wise
        0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x))).
    0.7978845608 is sqrt(2 / pi). Algebraically this is the same expression
    as NewGELU, rearranged to use ``x * x`` instead of ``pow``.
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation using the custom ``ops.gelu_fast`` kernel.

        The output is a fresh tensor shaped like ``x`` (``torch.empty_like``),
        which the kernel fills in place.
        """
        out = torch.empty_like(x)
        ops.gelu_fast(out, x)
        return out


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.

    The wrapped activation's output is divided element-wise by ``scales``.
    ``scales`` is a 1-D parameter of length ``intermediate_size``, or of
    length ``intermediate_size / tp_size`` when ``input_is_parallel`` is True.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            # Each tensor-parallel rank keeps only its own slice of the
            # scales (weight_loader below narrows by tp_rank accordingly).
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Allocated uninitialized; the values are expected to come from
        # weight_loader.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition,
                        dtype=params_dtype,
                        device="cuda"))
        # Attach weight_loader as an attribute of the parameter; presumably
        # the checkpoint-loading code calls it. Confirm against
        # set_weight_attrs' callers.
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation, then divide by ``self.scales``."""
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight``, or this rank's shard of it, into ``param``.

        Args:
            param: Destination parameter; its first dimension is the shard
                size.
            loaded_weight: Full (unsharded) tensor when ``input_is_parallel``
                is True, otherwise a tensor already matching ``param``.

        Raises:
            AssertionError: If the (sharded) weight's shape does not match
                ``param``.
        """
        param_data = param.data
        if self.input_is_parallel:
            # Select the contiguous slice along dim 0 that belongs to this
            # tensor-parallel rank.
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape, (
            f"ScaledActivation scales shape mismatch: parameter has "
            f"{tuple(param_data.shape)}, loaded weight has "
            f"{tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)


# Lower-case activation name -> module instance. Each entry is built once at
# import time, and get_act_fn returns the stored object directly, so all
# callers asking for the same (unscaled) name share one module.
_ACTIVATION_REGISTRY = dict(
    gelu=nn.GELU(),
    gelu_fast=FastGELU(),
    gelu_new=NewGELU(),
    gelu_pytorch_tanh=nn.GELU(approximate="tanh"),
    relu=nn.ReLU(),
)


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name.

    Args:
        act_fn_name: Activation name, matched case-insensitively against the
            keys of ``_ACTIVATION_REGISTRY``.
        quant_config: Optional quantization config. If it is given and lists
            the (lower-cased) name in ``get_scaled_act_names()``, the
            activation is wrapped in a ``ScaledActivation``.
        intermediate_size: Length of the scale vector. Required only when the
            activation gets wrapped.
        input_is_parallel: Forwarded to ``ScaledActivation``.
        params_dtype: Forwarded to ``ScaledActivation``.

    Returns:
        The registry's shared module instance, or a new ``ScaledActivation``
        wrapping that instance.

    Raises:
        ValueError: If the name is not in the registry, or if a scaled
            activation is required but ``intermediate_size`` is None.
    """
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn