# offline_inference_vision_language.py
"""
This example shows how to use vLLM for running offline inference 
with the correct prompt format on vision language models.

For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
12
from vllm.assets.video import VideoAsset
13
14
15
16
from vllm.utils import FlexibleArgumentParser


# LLaVA-1.5
17
18
def run_llava(question, modality):
    """Build the engine and prompt for LLaVA-1.5.

    Args:
        question: Question text inserted into the prompt.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"LLaVA-1.5 example only supports 'image' modality, "
            f"got {modality!r}.")

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
25
26
27


# LLaVA-1.6/LLaVA-NeXT
28
29
def run_llava_next(question, modality):
    """Build the engine and prompt for LLaVA-1.6 / LLaVA-NeXT.

    Args:
        question: Question text inserted into the prompt.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"LLaVA-NeXT example only supports 'image' modality, "
            f"got {modality!r}.")

    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only supports video input
39
40
41
def run_llava_next_video(question, modality):
    """Build the engine and prompt for LLaVA-NeXT-Video.

    Args:
        question: Question text inserted into the prompt.
        modality: Input modality; only ``"video"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"video"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "video":
        raise ValueError(
            f"LLaVA-NeXT-Video example only supports 'video' modality, "
            f"got {modality!r}.")

    prompt = f"USER: <video>\n{question} ASSISTANT:"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
46
47


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# LLaVA-OneVision
def run_llava_onevision(question, modality):
    """Build the engine and prompt for LLaVA-OneVision.

    Args:
        question: Question text inserted into the prompt.
        modality: Input modality, either ``"image"`` or ``"video"``. It
            selects the ``<image>`` / ``<video>`` placeholder in the prompt.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Validate before loading the model. Without this check an unsupported
    # modality leaves `prompt` unassigned and the function fails with
    # UnboundLocalError only after the engine has been constructed.
    if modality not in ("image", "video"):
        raise ValueError(
            f"LLaVA-OneVision example only supports 'image' or 'video' "
            f"modality, got {modality!r}.")

    # NOTE(review): the run of spaces between <|im_end|> and <|im_start|>
    # reproduces the prompt exactly as this example has always emitted it
    # (the whitespace came from a backslash line continuation inside the
    # string literal). Confirm against the model card whether it is intended.
    prompt = (f"<|im_start|>user <{modality}>\n{question}<|im_end|> "
              "        <|im_start|>assistant\n")

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=32768)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


65
# Fuyu
66
67
def run_fuyu(question, modality):
    """Build the engine and prompt for Fuyu.

    Args:
        question: Question text; the prompt is the question followed by a
            newline.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"Fuyu example only supports 'image' modality, "
            f"got {modality!r}.")

    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
73
74
75


# Phi-3-Vision
76
77
def run_phi3v(question, modality):
    """Build the engine and prompt for Phi-3-Vision.

    Args:
        question: Question text inserted into the prompt.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"Phi-3-Vision example only supports 'image' modality, "
            f"got {modality!r}.")

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids
93
94
95


# PaliGemma
96
97
def run_paligemma(question, modality):
    """Build the engine and prompt for PaliGemma.

    Args:
        question: Accepted for signature parity with the other model
            helpers; it is not used because the prompt is fixed.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"PaliGemma example only supports 'image' modality, "
            f"got {modality!r}.")

    # PaliGemma has special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
104
105
106


# Chameleon
107
108
def run_chameleon(question, modality):
    """Build the engine and prompt for Chameleon.

    Args:
        question: Question text; the ``<image>`` placeholder is appended
            directly after it.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"Chameleon example only supports 'image' modality, "
            f"got {modality!r}.")

    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
114
115
116


# MiniCPM-V
117
118
def run_minicpmv(question, modality):
    """Build the engine, prompt and stop tokens for MiniCPM-V (2.6 by default).

    The prompt is produced by the tokenizer's chat template from a single
    user message that contains the image placeholder and the question.

    Args:
        question: Question text placed after the image placeholder.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` holds
        the tokenizer ids of ``<|im_end|>`` and ``<|endoftext|>``.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"MiniCPM-V example only supports 'image' modality, "
            f"got {modality!r}.")

    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now.
    # For more details, see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    # model_name = "HwwwH/MiniCPM-V-2"

    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
    )
    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(token) for token in stop_tokens
    ]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids
155
156


157
# InternVL
158
159
160
def run_internvl(question, modality):
    """Build the engine, prompt and stop tokens for InternVL2.

    The prompt is produced by the tokenizer's chat template from a single
    user message that contains the ``<image>`` placeholder and the question.

    Args:
        question: Question text placed after the image placeholder.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` holds
        the tokenizer ids of the stop words listed in the function body.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"InternVL example only supports 'image' modality, "
            f"got {modality!r}.")

    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens; please refer to the
    # model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(token) for token in stop_tokens
    ]
    return llm, prompt, stop_token_ids
183
184


185
# BLIP-2
186
187
def run_blip2(question, modality):
    """Build the engine and prompt for BLIP-2.

    Args:
        question: Question text inserted into the
            ``"Question: ... Answer:"`` prompt.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"BLIP-2 example only supports 'image' modality, "
            f"got {modality!r}.")

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
195
196


197
# Qwen
198
199
def run_qwen_vl(question, modality):
    """Build the engine and prompt for Qwen-VL.

    Args:
        question: Question text; the ``Picture 1: <img></img>`` placeholder
            is appended directly after it.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"Qwen-VL example only supports 'image' modality, "
            f"got {modality!r}.")

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_num_seqs=5,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


212
# Qwen2-VL
213
214
215
def run_qwen2_vl(question, modality):
    """Build the engine and prompt for Qwen2-VL.

    Args:
        question: Question text placed after the vision placeholder tokens
            in the user turn.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        ``None`` for this model.

    Raises:
        ValueError: If ``modality`` is not ``"image"``.
    """
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if modality != "image":
        raise ValueError(
            f"Qwen2-VL example only supports 'image' modality, "
            f"got {modality!r}.")

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


231
232
233
# Maps each supported model-type name to the function that builds the
# (llm, prompt, stop_token_ids) triple for that model.
model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
}


248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def get_multi_modal_input(args):
    """Load the sample asset and question that match ``args.modality``.

    Args:
        args: Namespace with a ``modality`` attribute (``"image"`` or
            ``"video"``); ``num_frames`` is read only for video input.

    Returns:
        A dict with two keys: ``"data"`` (the RGB-converted sample image, or
        the sample video's ``np_ndarrays``) and ``"question"`` (the question
        to ask about it).

    Raises:
        ValueError: If the modality is neither ``"image"`` nor ``"video"``.
    """
    kind = args.modality

    # Reject anything unknown up front; both supported cases follow.
    if kind not in ("image", "video"):
        raise ValueError(f"Modality {kind} is not supported.")

    if kind == "video":
        # Input video and question
        clip = VideoAsset(name="sample_demo_1.mp4",
                          num_frames=args.num_frames)
        return {
            "data": clip.np_ndarrays,
            "question": "Why is this video funny?",
        }

    # Input image and question
    picture = ImageAsset("cherry_blossom").pil_image
    return {
        "data": picture.convert("RGB"),
        "question": "What is the content of this image?",
    }


281
282
283
284
285
def main(args):
    """Run single or batched vision-language inference and print the outputs.

    Args:
        args: Parsed command-line namespace. ``model_type``, ``modality``
            and ``num_prompts`` are read here; the namespace is also passed
            on to ``get_multi_modal_input``.

    Raises:
        ValueError: If ``num_prompts`` is not positive or ``model_type`` is
            not a key of ``model_example_map``.
    """
    # Check the cheap precondition first so a bad value fails before the
    # model is loaded. An explicit raise also survives `python -O`, unlike
    # the assert it replaces.
    num_prompts = args.num_prompts
    if num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be a positive integer, got {num_prompts}.")

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    # One request dict per prompt; every request pairs the same prompt with
    # the same image/video, keyed by the modality name.
    requests = [{
        "prompt": prompt,
        "multi_modal_data": {
            modality: data
        },
    } for _ in range(num_prompts)]

    # A single prompt is submitted as a bare dict, a batch as a list.
    inputs = requests[0] if num_prompts == 1 else requests

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the demo.
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    # Restricted to the model types registered in model_example_map.
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)