test_activation.py 3.15 KB
Newer Older
1
2
from typing import Callable, Tuple, Type

3
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import torch
5

6
from tests.kernels.utils import opcheck
7
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
8
9
                                                   NewGELU, QuickGELU,
                                                   SiluAndMul)
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
from .allclose_default import get_default_atol, get_default_rtol

13
14
15
16
# Parameter grids shared by the tests in this module (values are arbitrary).
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
SEEDS = [0]
# Use a single device when exactly one GPU is visible; otherwise (including
# when none is visible) parametrize over the first two device indices.
CUDA_DEVICES = (["cuda:0"] if torch.cuda.device_count() == 1 else
                ["cuda:0", "cuda:1"])
20

Woosuk Kwon's avatar
Woosuk Kwon committed
21

22
@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check an ``*AndMul`` activation layer and its custom op.

    A random ``(num_tokens, 2 * d)`` input is fed to the layer, and the
    result is compared bit-exactly against the layer's ``forward_native``.
    The matching ``torch.ops._C`` op is then validated with ``opcheck``,
    writing into a preallocated buffer whose last dimension is ``d``.

    Args:
        activation: One of ``"silu"``, ``"gelu"`` or ``"gelu_tanh"``.
        num_tokens: Number of rows in the random input.
        d: Half the input width; the op's output buffer has last dim ``d``.
        dtype: Element type of the input.
        seed: RNG seed for reproducible inputs.
        device: Device string, installed as torch's default device.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # NOTE: this sets torch's process-wide default device and is not restored
    # afterwards; tensors created below without an explicit device use it.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    if activation == "silu":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
        fn = torch.ops._C.gelu_and_mul
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
        fn = torch.ops._C.gelu_tanh_and_mul
    else:
        # Without this branch an unknown name would surface later as an
        # UnboundLocalError on ``layer``, hiding the real cause.
        raise ValueError(f"Unsupported activation: {activation!r}")

    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    # The op reads the (…, 2 * d) input and writes into a (…, d) buffer.
    # Reuse the ``d`` parameter rather than re-deriving and shadowing it.
    output_shape = x.shape[:-1] + (d, )
    op_out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    opcheck(fn, (op_out, x))
Woosuk Kwon's avatar
Woosuk Kwon committed
61

62
63
64
65

@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
                                        (NewGELU, torch.ops._C.gelu_new),
                                        (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Tuple[Type[torch.nn.Module], Callable],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check an element-wise GELU-variant layer and its custom op.

    A random ``(num_tokens, d)`` input is fed to the layer, and the result is
    compared against the layer's ``forward_native`` within tolerances chosen
    by the project's ``get_default_atol`` / ``get_default_rtol`` helpers.
    The matching custom op is then validated with ``opcheck`` using an
    output buffer shaped like the input.

    Args:
        activation: A ``(layer_class, custom_op)`` pair; the class is
            instantiated with no arguments.
        num_tokens: Number of rows in the random input.
        d: Width of the random input.
        dtype: Element type of the input.
        seed: RNG seed for reproducible inputs.
        device: Device string, installed as torch's default device.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # NOTE: this sets torch's process-wide default device and is not restored
    # afterwards; tensors created below without an explicit device use it.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)

    layer_cls, fn = activation
    layer = layer_cls()
    out = layer(x)
    ref_out = layer.forward_native(x)
    # Unlike the *AndMul test, these variants are compared with tolerances
    # rather than exactly; the helpers pick them based on ``out``.
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))

    op_out = torch.empty_like(x)
    opcheck(fn, (op_out, x))