# "vllm/entrypoints/openai/engine/protocol.py" did not exist on "0512c04aee408367a068b5960e7857c722ed204d"
# vision_embedding_online.py 11 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# ruff: noqa: E501
4
5
6
"""Example Python client for the multimodal embedding API of the vLLM API server.

Refer to each `run_*` function for the command to run the server for that model.
"""
8

9
10
11
import argparse
import base64
import io
12
from typing import Literal
13

14
15
16
17
from openai import OpenAI
from openai._types import NOT_GIVEN, NotGiven
from openai.types.chat import ChatCompletionMessageParam
from openai.types.create_embedding_response import CreateEmbeddingResponse
18
from PIL import Image
19

20
21
22
23
# Connection settings that point the OpenAI client at vLLM's API server
# rather than OpenAI's; the API key is only a placeholder value.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

# Sample inputs shared by the example functions below.
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg"
text = "A cat standing in the snow."
26

27

28
29
30
31
32
def create_chat_embeddings(
    client: OpenAI,
    *,
    messages: list[ChatCompletionMessageParam],
    model: str,
    encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN,
    continue_final_message: bool = False,
    add_special_tokens: bool = False,
) -> CreateEmbeddingResponse:
    """
    Convenience function for accessing vLLM's Chat Embeddings API,
    which is an extension of OpenAI's existing Embeddings API.

    The request carries chat-style ``messages`` (which may mix text and
    image parts) and is sent as a raw POST to the ``/embeddings`` route.

    Args:
        client: OpenAI client whose base URL points at the vLLM server.
        messages: Chat messages to embed.
        model: Name of the served model.
        encoding_format: ``"float"`` or ``"base64"``. When left as
            ``NOT_GIVEN`` the field is omitted from the request body.
        continue_final_message: Forwarded unchanged in the request body.
        add_special_tokens: Forwarded unchanged in the request body.

    Returns:
        The server response parsed as a ``CreateEmbeddingResponse``.
    """
    body: dict[str, object] = {
        "messages": messages,
        "model": model,
        "continue_final_message": continue_final_message,
        "add_special_tokens": add_special_tokens,
    }
    # Omit the field explicitly instead of placing the NOT_GIVEN sentinel in
    # the body and depending on the client library to strip it.
    if not isinstance(encoding_format, NotGiven):
        body["encoding_format"] = encoding_format

    return client.post(
        "/embeddings",
        cast_to=CreateEmbeddingResponse,
        body=body,
    )


54
55
56
57
58
def print_embeddings(embeds):
    """Print an embedding vector, abbreviated to its first four values.

    Vectors with more than four entries are shown as ``[a, b, c, d, ...]``;
    shorter ones are printed in full. The full length is always appended.
    """
    total = len(embeds)
    if total <= 4:
        preview = embeds
    else:
        head = str(embeds[:4])
        preview = head[:-1] + ", ...]"
    print(f"Embeddings: {preview} (size={total})")


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def run_clip(client: OpenAI, model: str):
    """
    Embed the sample image and a short caption with CLIP, one request each.

    Start the server using:

    vllm serve openai/clip-vit-base-patch32 \
        --runner pooling
    """
    # Each entry pairs the label to print with the single content part sent.
    queries = (
        (
            "Image embedding output:",
            {"type": "image_url", "image_url": {"url": image_url}},
        ),
        (
            "Text embedding output:",
            {"type": "text", "text": "a photo of a cat"},
        ),
    )

    for label, part in queries:
        result = create_chat_embeddings(
            client,
            messages=[{"role": "user", "content": [part]}],
            model=model,
            encoding_format="float",
        )
        print(label, result.data[0].embedding)


100
def run_dse_qwen2_vl(client: OpenAI, model: str):
    """
    Embed a document image and a text query with DSE (Qwen2-VL backbone).

    Start the server using:

    vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
        --runner pooling \
        --trust-remote-code \
        --max-model-len 8192 \
        --chat-template examples/template_dse_qwen2_vl.jinja
    """
    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url,
                        },
                    },
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
    # of the minimum input size
    with io.BytesIO() as buffer:
        Image.new("RGB", (56, 56)).save(buffer, "png")
        placeholder_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            # The placeholder was encoded as PNG above, so the
                            # data URL must declare image/png to match.
                            "url": f"data:image/png;base64,{placeholder_b64}",
                        },
                    },
                    {"type": "text", "text": "Query: What is the weather like today?"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)


162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
def run_qwen3_vl(client: OpenAI, model: str):
    """
    Embed text, an image, and an image+text pair with Qwen3-VL-Embedding.

    Every request places the user content between a system instruction and
    an empty assistant turn. It sets ``continue_final_message`` and
    ``add_special_tokens`` to True.

    Start the server using:

    vllm serve Qwen/Qwen3-VL-Embedding-2B \
        --runner pooling \
        --max-model-len 8192
    """
    instruction = "Represent the user's input."
    image_part = {"type": "image_url", "image_url": {"url": image_url}}

    def embed(user_parts):
        # Wrap the given user content parts in the fixed conversation shape
        # and return the first embedding from the response.
        conversation = [
            {"role": "system", "content": [{"type": "text", "text": instruction}]},
            {"role": "user", "content": user_parts},
            {"role": "assistant", "content": [{"type": "text", "text": ""}]},
        ]
        reply = create_chat_embeddings(
            client,
            messages=conversation,
            model=model,
            encoding_format="float",
            continue_final_message=True,
            add_special_tokens=True,
        )
        return reply.data[0].embedding

    cases = [
        ("Text embedding output:", [{"type": "text", "text": text}]),
        ("Image embedding output:", [image_part, {"type": "text", "text": ""}]),
        ("Image+Text embedding output:", [image_part, {"type": "text", "text": text}]),
    ]
    for heading, parts in cases:
        print(heading)
        print_embeddings(embed(parts))


269
270
271
272
273
def run_siglip(client: OpenAI, model: str):
    """
    Embed the sample image and a short caption with SigLIP, one request each.

    Start the server using:

    vllm serve google/siglip-base-patch16-224 \
        --runner pooling \
        --chat-template examples/template_basic.jinja
    """
    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "a photo of a cat"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)
309
310


311
def run_vlm2vec(client: OpenAI, model: str):
    """
    Embed an image, an image+question pair, and a caption with VLM2Vec.

    Start the server using:

    vllm serve TIGER-Lab/VLM2Vec-Full \
        --runner pooling \
        --trust-remote-code \
        --max-model-len 4096 \
        --chat-template examples/template_vlm2vec_phi3v.jinja
    """
    image_part = {"type": "image_url", "image_url": {"url": image_url}}

    # (label printed after the request, user content parts), in request order.
    requests = [
        (
            "Image embedding output:",
            [image_part, {"type": "text", "text": "Represent the given image."}],
        ),
        (
            "Image+Text embedding output:",
            [
                image_part,
                {
                    "type": "text",
                    "text": "Represent the given image with the following question: What is in the image.",
                },
            ],
        ),
        (
            "Text embedding output:",
            [{"type": "text", "text": "A cat and a dog"}],
        ),
    ]

    for label, content in requests:
        response = create_chat_embeddings(
            client,
            messages=[{"role": "user", "content": content}],
            model=model,
            encoding_format="float",
        )
        print(label)
        print_embeddings(response.data[0].embedding)
377
378
379


# Maps each accepted --model value to the function that runs its example;
# parse_args() offers these keys as the argparse choices.
model_example_map = {
    "clip": run_clip,
    "qwen3_vl": run_qwen3_vl,
    "dse_qwen2_vl": run_dse_qwen2_vl,
    "siglip": run_siglip,
    "vlm2vec": run_vlm2vec,
}
386
387


388
def parse_args():
    """Parse this script's command-line options.

    Returns:
        An ``argparse.Namespace`` whose ``model`` attribute is one of the
        keys of ``model_example_map``.
    """
    parser = argparse.ArgumentParser(
        # Must be passed as ``description``: ArgumentParser's first positional
        # parameter is ``prog``, which would replace the program name in the
        # usage line with this whole sentence.
        description=(
            "Script to call a specified VLM through the API. Make sure to serve "
            "the model with `--runner pooling` before running this."
        ),
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=model_example_map.keys(),
        required=True,
        help="The name of the embedding model.",
    )
    return parser.parse_args()

402

403
def main(args):
    """Connect to the vLLM server and run the example selected by ``--model``.

    The first model returned by ``client.models.list()`` is used as the
    target model ID.

    Raises:
        RuntimeError: If the server reports no models.
    """
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    # Fail with an actionable message rather than a bare IndexError.
    if not models.data:
        raise RuntimeError(
            f"No models are being served at {openai_api_base}; start the "
            "server as described in the selected run_* function first."
        )
    model_id = models.data[0].id

    model_example_map[args.model](client, model_id)
414
415


416
if __name__ == "__main__":
    # Only run when executed as a script, not when imported.
    args = parse_args()
    main(args)