test_activation.py 3.58 KB
Newer Older
1
import random
2
3
from typing import Type

4
import pytest
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import torch
6

7
from tests.kernels.utils import opcheck
8
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
9
10
11
                                                   GeluAndMul, MulAndSilu,
                                                   NewGELU, QuickGELU,
                                                   SiluAndMul)
12
from vllm.platforms import current_platform
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
15
from .allclose_default import get_default_atol, get_default_rtol

16
17
# Parameter grids shared by the activation tests in this module.
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 13824]  # Arbitrary values for testing
SEEDS = [0]
# Exactly one visible GPU -> test only cuda:0; any other count -> cuda:0 and
# cuda:1.
CUDA_DEVICES = [
    "cuda:" + str(dev_idx)
    for dev_idx in range(2 if torch.cuda.device_count() != 1 else 1)
]
23

Woosuk Kwon's avatar
Woosuk Kwon committed
24

25
26
27
@pytest.mark.parametrize(
    "activation",
    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Test the gated ``*_and_mul`` activation layers and their custom ops.

    For each activation name, the layer's dispatched forward (``layer(x)``)
    must match ``layer.forward_native(x)`` exactly, and the matching
    ``torch.ops._C`` op must pass ``opcheck`` on an output buffer whose
    last dimension is ``d``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    current_platform.seed_everything(seed)
    # NOTE(review): this changes the process-wide default device and is not
    # restored after the test.
    torch.set_default_device(device)
    # Input last dim is 2 * d; the op's output buffer below has last dim d.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    # Only FatReLU takes an extra argument; it is also passed to opcheck.
    threshold = None
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    elif activation == "mul_and_silu":
        layer = MulAndSilu()
        fn = torch.ops._C.mul_and_silu
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
        fn = torch.ops._C.gelu_and_mul
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
        fn = torch.ops._C.gelu_tanh_and_mul
    elif activation == "fatrelu":
        threshold = random.uniform(0, 1)
        layer = FatreluAndMul(threshold)
        fn = torch.ops._C.fatrelu_and_mul
    else:
        raise ValueError(f"Unsupported activation: {activation}")

    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
    # equivalent to the native PyTorch implementations, so we can do exact
    # comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    # The op writes into a preallocated buffer: same leading dims as x, last
    # dim d (half of x's last dim).
    output_shape = x.shape[:-1] + (d, )
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    if activation == "fatrelu":
        opcheck(fn, (out, x, threshold))
    else:
        opcheck(fn, (out, x))
Woosuk Kwon's avatar
Woosuk Kwon committed
75

76
77
78
79

@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
                                        (NewGELU, torch.ops._C.gelu_new),
                                        (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: tuple[Type[torch.nn.Module], object],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Test element-wise GELU variant layers and their custom ops.

    ``activation`` is a ``(layer_class, custom_op)`` pair taken from the
    parametrize list. The layer's dispatched forward (``layer(x)``) must
    match ``layer.forward_native(x)`` within the default tolerances for the
    output dtype, and ``custom_op`` must pass ``opcheck``.
    """
    current_platform.seed_everything(seed)
    # NOTE(review): this changes the process-wide default device and is not
    # restored after the test.
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)

    layer_cls, fn = activation
    layer = layer_cls()
    out = layer(x)
    ref_out = layer.forward_native(x)
    # Compare using the dtype-dependent default tolerances from
    # allclose_default rather than exact equality.
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))

    # Element-wise op: the output buffer has exactly the input's shape/dtype.
    out = torch.empty_like(x)
    opcheck(fn, (out, x))