# SPDX-License-Identifier: Apache-2.0

import random
from typing import Callable, Tuple, Type

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                   GeluAndMul, MulAndSilu,
                                                   NewGELU, QuickGELU,
                                                   SiluAndMul)
from vllm.platforms import current_platform

from .allclose_default import get_default_atol, get_default_rtol

# Floating-point element types every test in this module is run with.
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
# Row counts of the random inputs (arbitrary values for testing).
NUM_TOKENS = [7, 83, 2048]
# Sizes used for the last tensor dimension (arbitrary values for testing).
D = [512, 13824]
# Seeds handed to `current_platform.seed_everything`.
SEEDS = [0]
# A single device when exactly one GPU is reported; in every other case
# (including zero GPUs) the first two CUDA indices are listed.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(2 if torch.cuda.device_count() != 1 else 1)
]


@pytest.mark.parametrize(
    "activation",
    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check the "*_and_mul" activation layers against their native path.

    A random input of shape ``(num_tokens, 2 * d)`` is pushed through the
    selected layer and through the layer's ``forward_native``; both results
    must match exactly. The matching ``torch.ops._C`` op is then run through
    ``opcheck`` with a preallocated output of shape ``(num_tokens, d)``.

    Args:
        activation: Name selecting the layer / custom op pair under test.
        num_tokens: Number of rows of the random input.
        d: Half of the input's last dimension; last dimension of the output.
        dtype: Element type of the random input.
        seed: Value passed to ``current_platform.seed_everything``.
        device: Device string installed as the torch default device.

    Raises:
        ValueError: If ``activation`` is not one of the known names.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)

    # A single if/elif chain: exactly one branch selects the layer and op.
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    elif activation == "mul_and_silu":
        layer = MulAndSilu()
        fn = torch.ops._C.mul_and_silu
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
        fn = torch.ops._C.gelu_and_mul
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
        fn = torch.ops._C.gelu_tanh_and_mul
    elif activation == "fatrelu":
        # Drawn after the input tensor so the sampling order is unchanged.
        threshold = random.uniform(0, 1)
        layer = FatreluAndMul(threshold)
        fn = torch.ops._C.fatrelu_and_mul
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(f"Unknown activation: {activation!r}")

    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
    # equivalent to the native PyTorch implementations, so we can do exact
    # comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    # Fresh output buffer for opcheck: same leading dims as x, last dim d.
    output_shape = x.shape[:-1] + (d, )
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    if activation == "fatrelu":
        # The fatrelu op takes the threshold as an extra argument.
        opcheck(fn, (out, x, threshold))
    else:
        opcheck(fn, (out, x))


@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
                                        (NewGELU, torch.ops._C.gelu_new),
                                        (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Tuple[Type[torch.nn.Module], Callable],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check the element-wise GELU variants against their native path.

    A random input of shape ``(num_tokens, d)`` is pushed through the layer
    and through the layer's ``forward_native``; the results must agree within
    the default tolerances from ``allclose_default``. The paired
    ``torch.ops._C`` op is then run through ``opcheck`` with an output buffer
    shaped like the input.

    Args:
        activation: ``(layer class, custom op)`` pair under test.
        num_tokens: Number of rows of the random input.
        d: Size of the input's last dimension.
        dtype: Element type of the random input.
        seed: Value passed to ``current_platform.seed_everything``.
        device: Device string installed as the torch default device.
    """
    layer_cls, fn = activation

    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)

    layer = layer_cls()
    out = layer(x)
    ref_out = layer.forward_native(x)
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))

    # Fresh output buffer for opcheck, same shape and dtype as the input.
    out = torch.empty_like(x)
    opcheck(fn, (out, x))