test_activation.py 2.49 KB
Newer Older
1
2
from typing import Type

3
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import torch
5

6
7
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
                                                   NewGELU, SiluAndMul)
8
from allclose_default import get_default_atol, get_default_rtol
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
11
12
13
# Parametrization grids shared by the activation tests in this module.
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
# The sizes below are arbitrary values chosen for testing.
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
SEEDS = [0]
# Test on a single device when exactly one GPU is visible, otherwise on two.
# NOTE(review): with zero visible GPUs this still produces "cuda:0" and
# "cuda:1", so the tests error out rather than skip -- confirm intended.
_NUM_TEST_DEVICES = 1 if torch.cuda.device_count() == 1 else 2
CUDA_DEVICES = [f"cuda:{idx}" for idx in range(_NUM_TEST_DEVICES)]
17

Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check that an *AndMul layer exactly matches its ``_forward`` path.

    A random input of shape ``(num_tokens, 2 * d)`` is created on ``device``
    and passed both to ``layer(x)`` and to ``layer._forward(x)``; the two
    results must have the same shape and identical values.

    Raises:
        ValueError: if ``activation`` is not "silu", "gelu" or "gelu_tanh".
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # NOTE(review): this sets the process-wide default device and never
    # restores it, so the setting carries over into later tests.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    if activation == "silu":
        layer = SiluAndMul()
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
    else:
        # Fail with a clear message instead of an UnboundLocalError on `layer`.
        raise ValueError(f"Unsupported activation: {activation!r}")
    out = layer(x)
    ref_out = layer._forward(x)
    # torch.allclose accepts broadcastable shapes, so check shapes explicitly.
    assert out.shape == ref_out.shape
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
Woosuk Kwon's avatar
Woosuk Kwon committed
50
51


52
@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Type[torch.nn.Module],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check that a standalone activation layer matches its ``_forward`` path.

    ``activation`` is a layer class that is instantiated with no arguments.
    A random input of shape ``(num_tokens, d)`` is created on ``device``.
    ``layer(x)`` and ``layer._forward(x)`` must agree in shape and must be
    close within the tolerances returned by ``get_default_atol`` /
    ``get_default_rtol`` for the output tensor.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # NOTE(review): this sets the process-wide default device and never
    # restores it, so the setting carries over into later tests.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = activation()
    out = layer(x)
    ref_out = layer._forward(x)
    # torch.allclose accepts broadcastable shapes, so check shapes explicitly.
    assert out.shape == ref_out.shape
    assert torch.allclose(out,
                          ref_out,
                          atol=get_default_atol(out),
                          rtol=get_default_rtol(out))