test_siglip.py 4.46 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from typing import Any

6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import pytest
from transformers import SiglipModel

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_embeddings_close

# Text-only prompts; each one is embedded on its own with no image attached.
HF_TEXT_PROMPTS = ["a photo of a stop sign", "a photo of a cherry blossom"]

# Text prompt to use alongside each image asset, keyed by asset name. The
# prompts are empty so only the image contributes to the embedding.
# NOTE(review): test_models_image zips HF_IMAGE_PROMPTS with the
# `image_assets` fixture, so IMAGE_ASSETS.prompts presumably returns them in
# asset order — confirm in conftest.
_IMAGE_PROMPT_BY_ASSET = {
    "stop_sign": "",
    "cherry_blossom": "",
}
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(_IMAGE_PROMPT_BY_ASSET)

# Checkpoints exercised by every test in this module.
MODELS = [
    "google/siglip-base-patch16-224",
    "google/siglip2-base-patch16-224",
    # Different image embedding dim than text_config.hidden_size
    "google/siglip2-giant-opt-patch16-384",
]


def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    input_texts: list[str],
    input_images: PromptImageInput,
    model: str,
    *,
    dtype: str,
    tokenization_kwargs: dict[str, Any] | None = None,
    attention_config: dict[str, Any] | None = None,
) -> None:
    """Embed the same inputs with vLLM and HuggingFace and compare the results.

    Args:
        hf_runner: Factory for the HuggingFace reference runner.
        vllm_runner: Factory for the vLLM runner under test.
        input_texts: One text prompt per input.
        input_images: One image (or ``None`` for text-only inputs) per input,
            aligned with ``input_texts``.
        model: Model name handed to both runners.
        dtype: Dtype handed to both runners.
        tokenization_kwargs: Extra tokenizer arguments forwarded to both
            sides; ``None`` means no extra arguments.
        attention_config: Forwarded unchanged to ``vllm_runner``.
    """
    if tokenization_kwargs is None:
        tokenization_kwargs = {}

    # The vLLM engine is closed before the HF model is loaded; presumably this
    # keeps the two models from competing for GPU memory.
    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        enforce_eager=True,
        max_model_len=64,
        gpu_memory_utilization=0.7,
        attention_config=attention_config,
    ) as vllm_model:
        vllm_outputs = vllm_model.embed(
            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
        )

    with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model:
        all_inputs = hf_model.get_inputs(
            input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs
        )

        hf_outputs = []
        for inputs in all_inputs:
            inputs = hf_model.wrap_device(inputs)

            # Inputs that carry pixel values go through the image tower;
            # everything else goes through the text tower.
            # `.squeeze(0)` drops the leading batch dimension — assumes each
            # prepared input is a batch of one (TODO confirm in HfRunner).
            if "pixel_values" in inputs:
                pooled_output = hf_model.model.get_image_features(
                    pixel_values=inputs.pixel_values,
                ).squeeze(0)
            else:
                # NOTE(review): only input_ids is forwarded (no
                # attention_mask). Presumably this mirrors how SigLIP attends
                # over max_length padding and how vLLM runs it — confirm.
                pooled_output = hf_model.model.get_text_features(
                    input_ids=inputs.input_ids,
                ).squeeze(0)

            hf_outputs.append(pooled_output.tolist())

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,  # unused here; kept so the signature matches the image test
    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
    """Text-only embeddings from vLLM must match HuggingFace for each model."""
    input_texts = list(HF_TEXT_PROMPTS)
    # Text-only inputs: no prompt has an accompanying image.
    input_images = [None] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        model,
        dtype=dtype,
        # SigLIP and SigLIP2 were trained on text padded to max_length. Keep
        # max_length in sync with the max_model_len=64 used by _run_test.
        tokenization_kwargs={
            "padding": "max_length",
            "max_length": 64,
        },
        attention_config=siglip_attention_config,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
    """Image embeddings from vLLM must match HuggingFace for each model."""
    # Pair each (empty) image prompt with its asset; zip stops at the shorter
    # of the two sequences.
    prompt_asset_pairs = list(zip(HF_IMAGE_PROMPTS, image_assets))
    input_texts = [text for text, _ in prompt_asset_pairs]
    input_images = [asset.pil_image for _, asset in prompt_asset_pairs]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        model,
        dtype=dtype,
        attention_config=siglip_attention_config,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text_image_no_crash(
    vllm_runner,
    image_assets,
    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
    """Check how vLLM handles mixed and single-modality embed requests.

    A request carrying both non-empty text and an image must raise a
    ``ValueError`` mentioning "not both". Text-only requests and image
    requests with an empty text prompt must run without error.
    """
    texts = [HF_TEXT_PROMPTS[0]]
    images = [image_assets[0].pil_image]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        enforce_eager=True,
        max_model_len=64,
        gpu_memory_utilization=0.7,
        attention_config=siglip_attention_config,
    ) as vllm_model:
        # Non-empty text together with an image is rejected.
        with pytest.raises(ValueError, match="not both"):
            vllm_model.embed(texts, images=images)

        # Each modality on its own is accepted.
        vllm_model.embed(texts)
        vllm_model.embed([""], images=images)