"""Custom activation functions."""
from typing import Optional

import torch
import torch.nn as nn

from vllm import activation_ops
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[..., :d]) * x[..., d:] where
    d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
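

# A naive PyTorch sketch of what the fused activation_ops.silu_and_mul kernel
# above is expected to compute. _silu_and_mul_reference is a hypothetical
# helper kept here only for documentation; it is not used anywhere in vLLM.
def _silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply SiLU to the first half, and
    # gate the second half with it.
    d = x.shape[-1] // 2
    return nn.functional.silu(x[..., :d]) * x[..., d:]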


class NewGELU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        activation_ops.gelu_new(out, x)
        return out
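

# A reference sketch of the tanh-approximate ("new") GELU that the
# activation_ops.gelu_new kernel is assumed to implement (the GPT-2 style
# approximation). _gelu_new_reference is a hypothetical helper added only for
# illustration; it is not used by vLLM.
def _gelu_new_reference(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    c = 0.7978845608028654  # sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(
        c * (x + 0.044715 * torch.pow(x, 3.0))))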


class FastGELU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        activation_ops.gelu_fast(out, x)
        return out
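

# A reference sketch of the fast GELU approximation that the
# activation_ops.gelu_fast kernel is assumed to implement.
# _gelu_fast_reference is a hypothetical helper added only for illustration;
# it is not used by vLLM.
def _gelu_fast_reference(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
    return 0.5 * x * (1.0 + torch.tanh(
        x * 0.7978845608 * (1.0 + 0.044715 * x * x)))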


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition,
                        dtype=params_dtype,
                        device="cuda"))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = param_data.shape[0]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
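
# Illustrative construction (hypothetical values; get_act_fn below builds this
# wrapper automatically when the quantization method reports scaled
# activations):
#
#   act = ScaledActivation(NewGELU(), intermediate_size=11008)
#   y = act(x)  # equivalent to act.act(x) / act.scales, broadcast over the
#               # last dimension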


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn
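

# Example usage (illustrative only; the values below are hypothetical and no
# quantization config is passed, so the plain registry module is returned):
#
#   act = get_act_fn("gelu_new")                 # returns NewGELU()
#   y = act(torch.randn(8, 4096, device="cuda"))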