test_siglip.py 4.62 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from typing import Any

6
import pytest
7
import torch
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import SiglipModel

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_embeddings_close

# Text-only prompts; each one is embedded without an accompanying image.
HF_TEXT_PROMPTS = [
    "a photo of a stop sign",
    "a photo of a cherry blossom",
]

# Prompts paired with the image assets. Every asset gets an empty text prompt,
# so only the image contributes to the embedding. test_models_image zips these
# with the `image_assets` fixture; presumably both follow the same asset order.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(dict(stop_sign="", cherry_blossom=""))

# Checkpoints compared between vLLM and HF transformers.
MODELS = [
    "google/siglip-base-patch16-224",
    "google/siglip2-base-patch16-224",
    # Different image embedding dim than text_config.hidden_size
    "google/siglip2-giant-opt-patch16-384",
]


def _hf_pooled_embedding(hf_model, inputs) -> list:
    """Embed one HF input batch with SigLIP and return the pooled vector as a list.

    Batches carrying ``pixel_values`` go through ``get_image_features``; any
    other batch goes through ``get_text_features`` using only ``input_ids``.
    """
    batch = hf_model.wrap_device(inputs)
    siglip = hf_model.model

    if "pixel_values" in batch:
        features = siglip.get_image_features(pixel_values=batch.pixel_values)
    else:
        features = siglip.get_text_features(input_ids=batch.input_ids)

    # The getters may hand back a model-output object rather than a bare
    # tensor; in that case the pooled embedding lives in `pooler_output`.
    if not isinstance(features, torch.Tensor):
        features = features.pooler_output

    return features.squeeze(0).tolist()


def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    input_texts: list[str],
    input_images: PromptImageInput,
    model: str,
    *,
    dtype: str,
    tokenization_kwargs: dict[str, Any] | None = None,
    attention_config: dict[str, Any] | None = None,
) -> None:
    """Embed the same prompts with vLLM and with HF SigLIP and compare them.

    Args:
        hf_runner: Factory for the HF reference runner.
        vllm_runner: Factory for the vLLM runner.
        input_texts: Prompts to embed.
        input_images: Images passed alongside ``input_texts`` to both runners.
        model: Model name or path.
        dtype: Dtype used by both runners.
        tokenization_kwargs: Extra tokenizer options shared by both runners;
            ``None`` means no extra options.
        attention_config: Forwarded unchanged to the vLLM runner.
    """
    tok_kwargs = {} if tokenization_kwargs is None else tokenization_kwargs

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        enforce_eager=True,
        max_model_len=64,
        gpu_memory_utilization=0.7,
        attention_config=attention_config,
    )
    # The vLLM runner context is exited before the HF model is created.
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        vllm_outputs = vllm_model.embed(
            input_texts, images=input_images, tokenization_kwargs=tok_kwargs
        )

    with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model:
        hf_batches = hf_model.get_inputs(
            input_texts, images=input_images, tokenization_kwargs=tok_kwargs
        )
        hf_outputs = [_hf_pooled_embedding(hf_model, b) for b in hf_batches]

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
    """Embeddings of text-only prompts must agree between vLLM and HF."""
    input_texts = list(HF_TEXT_PROMPTS)
    # No prompt in this test carries an image.
    input_images = [None] * len(input_texts)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        model,
        dtype=dtype,
        # siglip2 was trained with this padding setting.
        tokenization_kwargs={"padding": "max_length", "max_length": 64},
        attention_config=siglip_attention_config,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
    """Embeddings of image prompts must agree between vLLM and HF."""
    # zip() stops at the shorter of the two sequences.
    prompt_asset_pairs = list(zip(HF_IMAGE_PROMPTS, image_assets))
    input_texts = [prompt for prompt, _ in prompt_asset_pairs]
    input_images = [asset.pil_image for _, asset in prompt_asset_pairs]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        model,
        dtype=dtype,
        attention_config=siglip_attention_config,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text_image_no_crash(
    vllm_runner,
    image_assets,
    siglip_attention_config,
    model: str,
    dtype: str,
) -> None:
    """Check how embed() handles text with an image, text alone, and image alone.

    Uses a single vLLM engine for all three calls.
    """
    first_prompt = HF_TEXT_PROMPTS[0]
    first_image = image_assets[0].pil_image
    texts = [first_prompt]
    images = [first_image]

    engine_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        enforce_eager=True,
        max_model_len=64,
        gpu_memory_utilization=0.7,
        attention_config=siglip_attention_config,
    )
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        # A non-empty text prompt combined with an image must raise a
        # ValueError whose message matches "not both".
        with pytest.raises(ValueError, match="not both"):
            vllm_model.embed(texts, images=images)

        # After that error, text alone and image alone (with an empty text
        # prompt) must both still be accepted by the same engine.
        vllm_model.embed(texts)
        vllm_model.embed([""], images=images)