# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models.py`.
"""
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
import pytest
8
9
10
import torch

from vllm.platforms import current_platform
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
from ....utils import large_gpu_mark
13
from ...registry import HF_EXAMPLE_MODELS
14
from ...utils import check_logprobs_close
15

16
17
18
19
20
21
# Models whose head_dim is unsupported by FA. There is no clean way to
# fall back for them, so we fail with a clear message when it happens;
# test_models pins these models to V0 via VLLM_USE_V1=0.
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = [
    "microsoft/phi-2",
    "stabilityai/stablelm-3b-4e1t",
]

22
23
24
25
26
27
28
29
# Models that are tested with AITER kernels enabled.
# When a test case requests AITER for a model that is not in this list,
# the case is skipped. Once more AITER kernels are added, this list will
# no longer be needed, because every model will call AITER kernels in
# some of its operators.
AITER_MODEL_LIST = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "openbmb/MiniCPM3-4B",
    "Qwen/Qwen-7B-Chat",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "TitanML/tiny-mixtral",
]

Woosuk Kwon's avatar
Woosuk Kwon committed
35

36
# @maybe_test_rocm_aiter
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(
            "bigscience/bloom-560m",  # bloom - testing alibi slopes
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openai-community/gpt2",  # gpt2
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param("Milos/slovak-gpt-j-405M"),  # gptj
        pytest.param("bigcode/tiny_starcoder_py"),  # gpt_bigcode
        pytest.param("EleutherAI/pythia-70m"),  # gpt_neox
        pytest.param(
            "google/gemma-1.1-2b-it",  # gemma
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "THUDM/chatglm3-6b",  # chatglm (text-only)
        ),
        pytest.param(
            "meta-llama/Llama-3.2-1B-Instruct",  # llama
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openbmb/MiniCPM3-4B",
            # fused_moe not supported on CPU
            marks=[pytest.mark.core_model,
                   large_gpu_mark(min_gb=32)],
        ),
        pytest.param(
            "facebook/opt-125m",  # opt
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "microsoft/phi-2",  # phi
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "Qwen/Qwen-7B-Chat",  # qwen (text-only)
        ),
        pytest.param(
            "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
            marks=[pytest.mark.core_model],
        ),
        pytest.param("stabilityai/stablelm-3b-4e1t"),  # stablelm
        pytest.param("bigcode/starcoder2-3b"),  # starcoder2
        pytest.param(
            "TitanML/tiny-mixtral",  # mixtral
            marks=[pytest.mark.cpu_model],
        )
    ])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
                monkeypatch) -> None:
    """Check that vLLM's greedy logprobs are close to the HF reference.

    The HF runner and the vLLM runner each generate greedy outputs with
    logprobs for the same prompts, and the two result lists are compared
    with ``check_logprobs_close``.

    Args:
        hf_runner: Fixture (defined outside this file) used as a context
            manager to build the HF reference model.
        vllm_runner: Fixture (defined outside this file) used as a context
            manager to build the vLLM model.
        example_prompts: Fixture providing the prompts fed to both runners.
        model: Model name under test.
        max_tokens: Number of tokens to generate per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        use_rocm_aiter: Whether to run this case with AITER enabled; only
            parametrized to ``True`` when the current platform is ROCm.
        monkeypatch: pytest fixture used to set environment variables for
            the duration of this test.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    # Both checks are asked to skip (not fail) the test when unmet.
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    if model in REQUIRES_V0:
        monkeypatch.setenv("VLLM_USE_V1", "0")

    if use_rocm_aiter:
        if model not in AITER_MODEL_LIST:
            # Skip models that are not covered by the AITER tests.
            # When more AITER kernels are added, this list will not be
            # needed as all the models will be calling AITER kernels
            # in parts of the operators.
            pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(
            model,
            tokenizer_name=model_info.tokenizer or model,
            tokenizer_mode=model_info.tokenizer_mode,
            trust_remote_code=model_info.trust_remote_code,
            max_num_seqs=2,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

    if use_rocm_aiter:
        # This is to ensure that the vLLM engine has deallocated its
        # memory before the next unit test runs. On ROCm, when using
        # AITER, the memory might not be deallocated completely before
        # the next test case starts.
        torch.cuda.synchronize()