"tests/models/test_paligemma.py" did not exist on "98d6682cd1f27fa48bf915d3fd3e1eb1ee3014c4"
test_common.py 6.71 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

Woosuk Kwon's avatar
Woosuk Kwon committed
5
import pytest
6
7
8
import torch

from vllm.platforms import current_platform
Woosuk Kwon's avatar
Woosuk Kwon committed
9

10
from ....utils import large_gpu_mark
11
from ...registry import HF_EXAMPLE_MODELS
12
from ...utils import check_logprobs_close
13

14
# These models have a head_dim that is unsupported by FA. We do not
# have a clean way to fall back, so we fail with a clear message
# when it happens.
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

20
21
22
23
24
25
26
27
# Models that use AITER kernels.
# Models that are not in this list are skipped in the AITER tests.
# Once more AITER kernels are added, this list will no longer be
# needed, because every model will call AITER kernels in at least
# some of its operators.
AITER_MODEL_LIST = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "openbmb/MiniCPM3-4B",
    "Qwen/Qwen-7B-Chat",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "TitanML/tiny-mixtral",
    "Qwen/Qwen3-8B",
]

Woosuk Kwon's avatar
Woosuk Kwon committed
34

35
# @maybe_test_rocm_aiter
36
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(
            "bigscience/bloom-560m",  # bloom - testing alibi slopes
            marks=[pytest.mark.core_model, pytest.mark.slow_test],
        ),
        pytest.param(
            "openai-community/gpt2",  # gpt2
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param("Milos/slovak-gpt-j-405M"),  # gptj
        pytest.param("bigcode/tiny_starcoder_py"),  # gpt_bigcode
        pytest.param("EleutherAI/pythia-70m"),  # gpt_neox
        pytest.param(
            "google/gemma-1.1-2b-it",  # gemma
            marks=[
                pytest.mark.core_model, pytest.mark.cpu_model,
                pytest.mark.slow_test
            ],
        ),
        pytest.param(
            "zai-org/chatglm3-6b",  # chatglm (text-only)
        ),
        pytest.param(
            "meta-llama/Llama-3.2-1B-Instruct",  # llama
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openbmb/MiniCPM3-4B",
            # fused_moe not supported on CPU
            marks=[pytest.mark.core_model,
                   large_gpu_mark(min_gb=32)],
        ),
        pytest.param(
            "facebook/opt-125m",  # opt
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "microsoft/phi-2",  # phi
            marks=[pytest.mark.core_model, pytest.mark.slow_test],
        ),
        pytest.param(
            "Qwen/Qwen-7B-Chat",  # qwen (text-only)
        ),
        pytest.param(
            "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
            marks=[
                pytest.mark.core_model, pytest.mark.cpu_model,
                pytest.mark.slow_test
            ],
        ),
        pytest.param(
            "Qwen/Qwen3-8B",  # qwen (text-only)
        ),
        pytest.param("stabilityai/stablelm-3b-4e1t"),  # stablelm
        pytest.param("bigcode/starcoder2-3b"),  # starcoder2
        pytest.param(
            "TitanML/tiny-mixtral",  # mixtral
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "allenai/OLMoE-1B-7B-0924-Instruct",
            marks=[pytest.mark.cpu_model],
        ),
        pytest.param("swiss-ai/Apertus-8B-2509"),  # apertus
    ])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
                use_prompt_embeds: bool, monkeypatch) -> None:
    """Check that vLLM greedy logprobs match the HF runner's for ``model``.

    The HF outputs are compared against the vLLM outputs generated from the
    text prompts. When ``use_prompt_embeds`` is set, the prompts are also
    embedded with the HF model's input-embedding layer, fed to vLLM as
    prompt embeddings, and those outputs are compared against the vLLM
    text-prompt outputs.

    Args:
        hf_runner: Fixture used as a context manager yielding the HF model.
        vllm_runner: Fixture used as a context manager yielding the vLLM
            model.
        example_prompts: Fixture providing the prompts to generate from.
        model: Name of the model under test.
        max_tokens: Maximum number of tokens to generate per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        use_rocm_aiter: Whether to run with ``VLLM_ROCM_USE_AITER=1``; models
            outside ``AITER_MODEL_LIST`` are skipped in that case.
        use_prompt_embeds: Whether to additionally test prompt-embeds input.
        monkeypatch: pytest fixture used to set environment variables.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    if model in REQUIRES_V0:
        monkeypatch.setenv("VLLM_USE_V1", "0")

    if use_rocm_aiter:
        if model in AITER_MODEL_LIST:
            monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
        else:
            # Skip models that are not using AITER kernels.
            # When more AITER kernels are added, this list will not be
            # needed as all the models will be calling AITER kernels
            # in parts of the operators.
            pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

    # Note: can be removed when
    # https://github.com/vllm-project/vllm/pull/24278 finished
    if current_platform.is_cpu() and use_prompt_embeds:
        pytest.skip("Skipping use_prompt_embeds=True with "
                    "V1-only CPU backend.")

    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

        # Only tokenize and embed the prompts when the embeddings are
        # actually going to be fed to vLLM.
        prompt_embeds: Optional[list[torch.Tensor]] = None
        if use_prompt_embeds:
            embedding_layer = hf_model.model.get_input_embeddings()
            prompt_embeds = []
            for prompt in example_prompts:
                token_ids = hf_model.tokenizer(
                    prompt, return_tensors="pt").input_ids.to(
                        hf_model.model.device)
                # squeeze(0) drops the leading dimension of the result
                # (presumably the batch dimension added by
                # return_tensors="pt" - TODO confirm).
                prompt_embeds.append(embedding_layer(token_ids).squeeze(0))

    with vllm_runner(
            model,
            tokenizer_name=model_info.tokenizer or model,
            tokenizer_mode=model_info.tokenizer_mode,
            trust_remote_code=model_info.trust_remote_code,
            max_num_seqs=2,
            enable_prompt_embeds=use_prompt_embeds,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)
        if prompt_embeds is not None:
            vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
                prompt_embeds, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
    if prompt_embeds is not None:
        check_logprobs_close(
            outputs_0_lst=vllm_outputs,
            outputs_1_lst=vllm_outputs_from_embeds,
            name_0="vllm",
            name_1="vllm_from_embeds",
        )

    if use_rocm_aiter:
        # this is to ensure that vllm engine
        # has deallocated the memory before running the next
        # unit tests. On ROCm, when using AITER
        # the memory might not be deallocated completely
        # before running the next test case
        torch.cuda.synchronize()