# SPDX-License-Identifier: Apache-2.0

"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models.py`.
"""

import pytest
import torch

from vllm.platforms import current_platform
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
# These models have a head_dim that FA does not support. We do not have a
# clean way to fall back, so we fail with a clear message when it happens;
# test_models therefore runs them on the V0 engine (VLLM_USE_V1=0).
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

# Models that use AITER kernels. When the test runs with `use_rocm_aiter`
# enabled, models not in this list are skipped. Once more AITER kernels are
# added, every model will call AITER kernels in parts of its operators and
# this list will no longer be needed.
AITER_MODEL_LIST = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "openbmb/MiniCPM3-4B",
    "Qwen/Qwen-7B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    # "Insruct" (sic) is intentional as far as we can tell: presumably it
    # matches the upstream Hugging Face repo id -- verify before correcting.
    "ehristoforu/Falcon3-MoE-2x7B-Insruct",
]

@pytest.mark.parametrize(
    "model_arch",
    [
        pytest.param(
            "BloomForCausalLM",  # testing alibi slopes
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "GPT2LMHeadModel",  # gpt2
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param("GPTJForCausalLM"),
        pytest.param("GPTBigCodeForCausalLM"),
        pytest.param("GPTNeoXForCausalLM"),
        pytest.param(
            "GemmaForCausalLM",  # gemma
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param("GlmForCausalLM"),
        pytest.param(
            "LlamaForCausalLM",
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "MiniCPM3ForCausalLM",
            # fused_moe not supported on CPU
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "OPTForCausalLM",
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "PhiForCausalLM",
            marks=[pytest.mark.core_model],
        ),
        pytest.param("QWenLMHeadModel"),
        pytest.param(
            "Qwen2ForCausalLM",
            marks=[pytest.mark.core_model],
        ),
        pytest.param("StableLmForCausalLM"),
        pytest.param("Starcoder2ForCausalLM"),
        pytest.param(
            "MixtralForCausalLM",
            marks=[pytest.mark.cpu_model],
        ),
    ])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model_arch: str,
                dtype: str, max_tokens: int, num_logprobs: int,
                use_rocm_aiter: bool, monkeypatch) -> None:
    """Compare greedy-decoding logprobs of HF and vLLM for ``model_arch``.

    The checkpoint is the architecture's default example model from
    ``HF_EXAMPLE_MODELS``. Both runners receive the same ``example_prompts``,
    ``max_tokens`` and ``num_logprobs``, and their outputs are compared with
    ``check_logprobs_close``.

    Models in ``REQUIRES_V0`` are run with ``VLLM_USE_V1=0``. When
    ``use_rocm_aiter`` is set (only parametrized to True on ROCm), models
    outside ``AITER_MODEL_LIST`` are skipped and the rest run with
    ``VLLM_ROCM_USE_AITER=1``. Environment changes go through ``monkeypatch``.
    """
    model = HF_EXAMPLE_MODELS.get_hf_info(model_arch).default

    if model in REQUIRES_V0:
        monkeypatch.setenv("VLLM_USE_V1", "0")

    if use_rocm_aiter:
        # Only the models in AITER_MODEL_LIST are tested with AITER kernels.
        # Once more AITER kernels are added, this guard will not be needed.
        if model not in AITER_MODEL_LIST:
            pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with hf_runner(model, dtype=dtype) as hf_model:
        if model.startswith("THUDM/chatglm3"):
            # Make HF's output-embedding accessor return chatglm3's
            # `transformer.output_layer` (presumably the standard accessor
            # does not expose it for this model -- TODO confirm).
            hf_model.model.get_output_embeddings = lambda: \
                hf_model.model.transformer.output_layer

        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    if use_rocm_aiter:
        # Ensure the vLLM engine has deallocated its memory before the next
        # unit test runs: on ROCm with AITER, memory might not be fully
        # released otherwise. This is done before the comparison so that a
        # failing check (presumably raising AssertionError) cannot skip it.
        torch.cuda.synchronize()

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )