"""Custom activation functions."""

import torch
import torch.nn as nn

from vllm import activation_ops


class SiluAndMul(nn.Module):
    """Gated SiLU activation used by SwiGLU feed-forward layers.

    The last dimension of the input is treated as two halves of width
    ``d = x.shape[-1] // 2``. The documented contract is
    ``x -> silu(x[..., :d]) * x[..., d:]``; the arithmetic itself is delegated
    to ``activation_ops.silu_and_mul``, which is handed a preallocated output
    tensor together with the input.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        half_width = x.shape[-1] // 2
        # Same leading dimensions as the input, half the trailing width.
        # The buffer is uninitialized; the kernel is expected to fill it.
        result = torch.empty(
            (*x.shape[:-1], half_width),
            dtype=x.dtype,
            device=x.device,
        )
        activation_ops.silu_and_mul(result, x)
        return result
24
25
26
27
28


class NewGELU(nn.Module):
    """GELU variant computed by the ``activation_ops.gelu_new`` kernel.

    The returned tensor is allocated with ``torch.empty_like`` and therefore
    matches the input's shape, dtype and device.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Uninitialized buffer handed to the kernel, which is expected to
        # write the activation values into it.
        result = torch.empty_like(x)
        activation_ops.gelu_new(result, x)
        return result


class FastGELU(nn.Module):
    """GELU variant computed by the ``activation_ops.gelu_fast`` kernel.

    The returned tensor is allocated with ``torch.empty_like`` and therefore
    matches the input's shape, dtype and device.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Uninitialized buffer handed to the kernel, which is expected to
        # write the activation values into it.
        result = torch.empty_like(x)
        activation_ops.gelu_fast(result, x)
        return result


# Lookup table from lower-case activation name to a module instance.
# Each instance is created once, when this module is imported, and
# get_act_fn returns that same object on every lookup.
_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),  # dispatches to activation_ops.gelu_fast
    "gelu_new": NewGELU(),  # dispatches to activation_ops.gelu_new
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Look up an activation module by name, ignoring case.

    Args:
        act_fn: Name of the activation; it is lower-cased before the lookup.

    Returns:
        The module stored in ``_ACTIVATION_REGISTRY`` under the lower-cased
        name. The registry's own instance is returned, not a copy.

    Raises:
        ValueError: If the lower-cased name is not a registry key.
    """
    act_fn = act_fn.lower()
    module = _ACTIVATION_REGISTRY.get(act_fn)
    if module is None:
        # The message reports the normalized (lower-cased) name.
        raise ValueError(f"Activation function {act_fn!r} is not supported.")
    return module