"""Custom activation functions."""
import torch
import torch.nn as nn

from cacheflow import activation_ops

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_new": nn.GELU(approximate="tanh"),   # NOTE: This may introduce small rounding errors.
    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")

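# Illustrative usage sketch (an assumption about typical callers, not part of
# the original module): the name usually comes from a model config's
# activation-function field, e.g. `hidden_act`.
#
#   act = get_act_fn("gelu_new")
#   y = act(torch.randn(4, 16))  # output has the same shape as the input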

class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:, :d]) * x[:, d:] where
    d = x.shape[1] // 2.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,        # (num_tokens, 2 * d)
    ) -> torch.Tensor:          # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        # The fused CUDA kernel writes silu(x[:, :d]) * x[:, d:] into `out`.
        activation_ops.silu_and_mul(out, x)
        return out
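
# A pure-PyTorch sketch of the math the fused `activation_ops.silu_and_mul`
# kernel is expected to perform, based on the docstring above. This is an
# illustrative reference only (an assumption about the compiled op's
# semantics), not a drop-in replacement for the CUDA kernel.
def _silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[1] // 2
    return nn.functional.silu(x[:, :d]) * x[:, d:]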