vision_language_pooling.py 8.25 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for multimodal pooling.

For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
"""
10

Cyrus Leung's avatar
Cyrus Leung committed
11
from argparse import Namespace
12
from dataclasses import asdict
Cyrus Leung's avatar
Cyrus Leung committed
13
14
15
16
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args

from PIL.Image import Image

17
from vllm import LLM, EngineArgs
18
from vllm.entrypoints.score_utils import ScoreMultiModalParam
Cyrus Leung's avatar
Cyrus Leung committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser


class TextQuery(TypedDict):
    """Query made of a text string only."""

    modality: Literal["text"]
    text: str


class ImageQuery(TypedDict):
    """Query made of a single PIL image only."""

    modality: Literal["image"]
    image: Image


class TextImageQuery(TypedDict):
    """Query made of a text string together with a single PIL image."""

    modality: Literal["text+image"]
    text: str
    image: Image


39
40
41
42
43
44
45
46
class TextImagesQuery(TypedDict):
    """Query pairing a text string with multimodal score content."""

    modality: Literal["text+images"]
    text: str
    # Despite the singular name, this holds the content that
    # `run_jinavl_reranker` forwards as the `documents` of a scoring request.
    image: ScoreMultiModalParam


# Values accepted in the `modality` field of a query; `parse_args` also uses
# them (via `get_args`) as the choices for `--modality`.
QueryModality = Literal["text", "image", "text+image", "text+images"]
# Any one of the query shapes declared in this file.
Query = Union[TextQuery, ImageQuery, TextImageQuery, TextImagesQuery]
Cyrus Leung's avatar
Cyrus Leung committed
47
48
49


class ModelRequestData(NamedTuple):
    """Engine configuration plus the inputs for one pooling request.

    Embedding examples fill in ``prompt`` (and optionally ``image``);
    scoring examples fill in ``query`` and ``documents``. Fields that an
    example does not use keep their ``None`` default.
    """

    # Arguments used to construct the `LLM` instance.
    engine_args: EngineArgs
    # Inputs consumed by `run_encode`.
    prompt: Optional[str] = None
    image: Optional[Image] = None
    # Inputs consumed by `run_score`.
    query: Optional[str] = None
    documents: Optional[ScoreMultiModalParam] = None
Cyrus Leung's avatar
Cyrus Leung committed
55
56


57
def run_e5_v(query: Query) -> ModelRequestData:
    """Build the embedding request for the ``royokong/e5-v`` model.

    Args:
        query: A ``"text"`` or ``"image"`` query.

    Returns:
        Request data carrying the engine arguments, the formatted prompt
        and the image (``None`` for text queries).

    Raises:
        ValueError: If the query modality is neither ``"text"`` nor
            ``"image"``.
    """
    llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"  # noqa: E501

    modality = query["modality"]
    if modality == "text":
        text = query["text"]
        prompt = llama3_template.format(f"{text}\nSummary above sentence in one word: ")
        image = None
    elif modality == "image":
        prompt = llama3_template.format("<image>\nSummary above image in one word: ")
        image = query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="royokong/e5-v",
        runner="pooling",
        max_model_len=4096,
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image=image,
    )


85
def run_vlm2vec(query: Query) -> ModelRequestData:
    """Build the embedding request for the ``TIGER-Lab/VLM2Vec-Full`` model.

    Args:
        query: A ``"text"``, ``"image"`` or ``"text+image"`` query.

    Returns:
        Request data carrying the engine arguments, the prompt and the
        image (``None`` for text queries).

    Raises:
        ValueError: If the query modality is not one of the three above.
    """
    modality = query["modality"]
    if modality == "text":
        text = query["text"]
        prompt = f"Find me an everyday image that matches the given caption: {text}"  # noqa: E501
        image = None
    elif modality == "image":
        prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image."  # noqa: E501
        image = query["image"]
    elif modality == "text+image":
        text = query["text"]
        prompt = (
            f"<|image_1|> Represent the given image with the following question: {text}"  # noqa: E501
        )
        image = query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="TIGER-Lab/VLM2Vec-Full",
        runner="pooling",
        max_model_len=4096,
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": 4},
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image=image,
    )


119
120
121
122
123
124
def run_jinavl_reranker(query: Query) -> ModelRequestData:
    """Build the scoring request for the ``jinaai/jina-reranker-m0`` model.

    Args:
        query: A ``"text+images"`` query.

    Returns:
        Request data whose ``query`` is the query text and whose
        ``documents`` is the query's ``image`` entry.

    Raises:
        ValueError: If the query modality is not ``"text+images"``.
    """
    modality = query["modality"]
    if modality != "text+images":
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="jinaai/jina-reranker-m0",
        runner="pooling",
        max_model_len=32768,
        trust_remote_code=True,
        mm_processor_kwargs={
            "min_pixels": 3136,
            "max_pixels": 602112,
        },
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        query=query["text"],
        documents=query["image"],
    )


Cyrus Leung's avatar
Cyrus Leung committed
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def get_query(modality: QueryModality) -> Query:
    """Return the sample query for the requested modality.

    The ``"image"`` and ``"text+image"`` samples download their image with
    ``fetch_image``; the ``"text+images"`` sample only carries image URLs.

    Args:
        modality: Which kind of sample query to build.

    Returns:
        The sample query whose ``modality`` field equals ``modality``.

    Raises:
        ValueError: If ``modality`` is not a supported value.
    """
    if modality == "text":
        return TextQuery(modality="text", text="A dog sitting in the grass")

    if modality == "image":
        return ImageQuery(
            modality="image",
            image=fetch_image(
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"  # noqa: E501
            ),
        )

    if modality == "text+image":
        return TextImageQuery(
            modality="text+image",
            text="A cat standing in the snow.",
            image=fetch_image(
                "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg"  # noqa: E501
            ),
        )

    if modality == "text+images":
        return TextImagesQuery(
            modality="text+images",
            text="slm markdown",
            image={
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"  # noqa: E501
                        },
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"  # noqa: E501
                        },
                    },
                ]
            },
        )

    msg = f"Modality {modality} is not supported."
    raise ValueError(msg)


189
def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
    """Embed the sample query for ``modality`` with ``model`` and print it.

    Args:
        model: Key into ``model_example_map``.
        modality: Modality of the sample query to embed.
        seed: Seed passed to the engine; ``None`` is passed through as is.

    Raises:
        ValueError: If the selected model example does not produce a prompt
            (it builds a scoring request instead), or if the model example
            rejects the modality.
    """
    query = get_query(modality)
    req_data = model_example_map[model](query)

    # Fail before the (expensive) engine start-up rather than inside
    # `llm.embed` when the example only provides scoring inputs.
    if req_data.prompt is None:
        raise ValueError(
            f"Model '{model}' does not provide an embedding prompt; "
            "it cannot be used with the embedding task."
        )

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": seed}
    llm = LLM(**engine_args)

    # Only attach image data when the query actually carries an image.
    mm_data = {}
    if req_data.image is not None:
        mm_data["image"] = req_data.image

    outputs = llm.embed(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": mm_data,
        }
    )

    print("-" * 50)
    for output in outputs:
        print(output.outputs.embedding)
        print("-" * 50)
Cyrus Leung's avatar
Cyrus Leung committed
217
218


219
220
221
222
223
224
225
226
227
228
229
230
231
232
def run_score(model: str, modality: QueryModality, seed: Optional[int]):
    """Score the sample query for ``modality`` with ``model`` and print it.

    Args:
        model: Key into ``model_example_map``.
        modality: Modality of the sample query to score.
        seed: Seed passed to the engine; ``None`` is passed through as is.

    Raises:
        ValueError: If the selected model example does not produce both a
            query and documents (it builds an embedding request instead),
            or if the model example rejects the modality.
    """
    query = get_query(modality)
    req_data = model_example_map[model](query)

    # Fail before the (expensive) engine start-up rather than inside
    # `llm.score` when the example only provides embedding inputs.
    if req_data.query is None or req_data.documents is None:
        raise ValueError(
            f"Model '{model}' does not provide a query and documents; "
            "it cannot be used with the scoring task."
        )

    engine_args = asdict(req_data.engine_args) | {"seed": seed}
    llm = LLM(**engine_args)

    outputs = llm.score(req_data.query, req_data.documents)

    print("-" * 30)
    print([output.outputs.score for output in outputs])
    print("-" * 30)


Cyrus Leung's avatar
Cyrus Leung committed
233
234
235
# Maps each `--model-name` choice to the function that builds its request.
model_example_map = {
    "e5_v": run_e5_v,
    "vlm2vec": run_vlm2vec,
    "jinavl_reranker": run_jinavl_reranker,
}

239
240

def parse_args():
    """Parse the command-line options of this example.

    Returns:
        The parsed arguments with attributes ``model_name``, ``task``,
        ``modality`` and ``seed``.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for multimodal pooling tasks."
    )
    parser.add_argument(
        "--model-name",
        "-m",
        type=str,
        default="vlm2vec",
        choices=model_example_map.keys(),
        help="The name of the embedding model.",
    )
    parser.add_argument(
        "--task",
        "-t",
        type=str,
        default="embedding",
        choices=["embedding", "scoring"],
        help="The task type.",
    )
    parser.add_argument(
        "--modality",
        type=str,
        default="image",
        choices=get_args(QueryModality),
        help="Modality of the input.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()
275

276
277

def main(args: Namespace):
    """Run the example selected by ``args.task``.

    Args:
        args: Parsed options; ``task``, ``model_name``, ``modality`` and
            ``seed`` are read.

    Raises:
        ValueError: If ``args.task`` is neither ``"embedding"`` nor
            ``"scoring"``.
    """
    if args.task == "embedding":
        run_encode(args.model_name, args.modality, args.seed)
    elif args.task == "scoring":
        run_score(args.model_name, args.modality, args.seed)
    else:
        raise ValueError(f"Unsupported task: {args.task}")
284
285
286
287


# Script entry point: parse the CLI options and run the selected example.
if __name__ == "__main__":
    args = parse_args()
    main(args)