# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
"""
9
import os
10
import random
11
12
from dataclasses import asdict
from typing import NamedTuple, Optional
13

14
from huggingface_hub import snapshot_download
15
16
from transformers import AutoTokenizer

17
from vllm import LLM, EngineArgs, SamplingParams
18
from vllm.assets.image import ImageAsset
19
from vllm.assets.video import VideoAsset
20
from vllm.lora.request import LoRARequest
21
22
from vllm.utils import FlexibleArgumentParser

23
24
25
26
27
28
29
30

class ModelRequestData(NamedTuple):
    """Everything one example builder hands back to ``main``.

    Each ``run_*`` function in this file returns one of these.
    """
    # Keyword arguments for constructing the ``LLM`` engine.
    engine_args: EngineArgs
    # Model-specific formatted prompts, one per input question.
    prompts: list[str]
    # Token ids that should end generation for this prompt format, if any.
    stop_token_ids: Optional[list[int]] = None
    # LoRA adapters that ``main`` registers on the engine before generating.
    lora_requests: Optional[list[LoRARequest]] = None


31
32
33
34
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

35

36
# Aria
def run_aria(questions: list[str], modality: str) -> ModelRequestData:
    """Aria: one image per prompt in a ChatML-style template."""
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine_args = EngineArgs(
        model="rhymes-ai/Aria",
        max_model_len=4096,
        max_num_seqs=2,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    head = "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>"
    tail = "<|im_end|>\n<|im_start|>assistant\n"
    prompts = [f"{head}{q}{tail}" for q in questions]

    # Token ids that terminate an Aria response.
    aria_stop_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=aria_stop_ids,
    )
61
62
63


# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    """BLIP-2 (OPT-6.7B) with plain "Question: ... Answer:" prompts."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="Salesforce/blip2-opt-6.7b",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 # noqa: E501
    prompts = []
    for q in questions:
        prompts.append(f"Question: {q} Answer:")

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
79
80
81


# Chameleon
def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
    """Chameleon-7B: the image placeholder goes after the question text."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="facebook/chameleon-7b",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"{q}<image>" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
97
98


99
# Deepseek-VL2
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
    """DeepSeek-VL2 (tiny) using its ``<|User|>/<|Assistant|>`` format."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="deepseek-ai/deepseek-vl2-tiny",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        # Override the `architectures` entry read from the HF config.
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    )

    prompts = []
    for q in questions:
        prompts.append(f"<|User|>: <image>\n{q}\n\n<|Assistant|>:")

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
122
123


124
# Florence2
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
    """Florence-2 large.

    The questions themselves are not used: every prompt is the
    ``<MORE_DETAILED_CAPTION>`` task string, one per question.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="facebook/bart-large",
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = ["<MORE_DETAILED_CAPTION>"] * len(questions)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
144
145


146
# Fuyu
def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
    """Fuyu-8B: the question followed by a newline is the whole prompt."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="adept/fuyu-8b",
        max_model_len=2048,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"{q}\n" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
162
163


164
# Gemma 3
def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
    """Gemma 3 (4B instruct) with pan-and-scan image preprocessing."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="google/gemma-3-4b-it",
        max_model_len=2048,
        max_num_seqs=2,
        mm_processor_kwargs={"do_pan_and_scan": True},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    opening = "<bos><start_of_turn>user\n<start_of_image>"
    closing = "<end_of_turn>\n<start_of_turn>model\n"
    prompts = [f"{opening}{q}{closing}" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
185
186


187
# GLM-4v
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
    """GLM-4v-9B: one image per prompt, with explicit stop token ids.

    Returns a ``ModelRequestData`` whose prompts have the form
    ``<|user|>\\n<|begin_of_image|><|endoftext|><|end_of_image|>{question}<|assistant|>``.
    """
    assert modality == "image"
    model_name = "THUDM/glm-4v-9b"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_seqs=2,
        trust_remote_code=True,
        enforce_eager=True,
        hf_overrides={"architectures": ["GLM4VForCausalLM"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # Built from adjacent literals on purpose: a backslash line continuation
    # *inside* the string would splice the next line's indentation into the
    # prompt between <|end_of_image|> and the question.
    prompts = [("<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
                f"{question}<|assistant|>") for question in questions]

    stop_token_ids = [151329, 151336, 151338]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
214
215
216


# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
    """H2OVL-Mississippi 800M with prompts from its HF chat template."""
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-800m"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    conversations = []
    for q in questions:
        conversations.append([{'role': 'user', 'content': f"<image>\n{q}"}])
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-800m
    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=[tokenizer.eos_token_id],
    )
248
249
250


# Idefics3-8B-Llama3
def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
    """Idefics3-8B (Llama 3) with a capped image longest edge."""
    assert modality == "image"

    # if you are running out of memory, you can reduce the "longest_edge".
    # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_kwargs = {"size": {"longest_edge": 3 * 364}}

    engine_args = EngineArgs(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for q in questions:
        prompts.append(
            f"<|begin_of_text|>User:<image>{q}<end_of_utterance>\nAssistant:")

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
277
278
279


# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    """InternVL2-2B with chat-template prompts and stop tokens.

    Returns a ``ModelRequestData`` whose ``stop_token_ids`` contains only
    the ids the tokenizer actually resolved.
    """
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [[{
        'role': 'user',
        'content': f"<image>\n{question}"
    }] for question in questions]
    prompts = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    # A tokenizer without an unk token can return None for a token it does
    # not know; a None stop id is meaningless, so drop it.
    stop_token_ids = [tid for tid in stop_token_ids if tid is not None]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
314
315


316
# LLaVA-1.5
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
    """LLaVA-1.5 7B in the ``USER: ... ASSISTANT:`` format."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for q in questions:
        prompts.append(f"USER: <image>\n{q}\nASSISTANT:")

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
334
335
336


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
    """LLaVA-NeXT (Mistral-7B) in the ``[INST] ... [/INST]`` format."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"[INST] <image>\n{q} [/INST]" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
351
352
353
354


# LLaVA-NeXT-Video
# Currently only supports video input
def run_llava_next_video(questions: list[str],
                         modality: str) -> ModelRequestData:
    """LLaVA-NeXT-Video 7B with one video per prompt."""
    assert modality == "video"

    engine_args = EngineArgs(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"USER: <video>\n{q} ASSISTANT:" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
373
374


375
# LLaVA-OneVision
def run_llava_onevision(questions: list[str],
                        modality: str) -> ModelRequestData:
    """LLaVA-OneVision (Qwen2-7B) for image or video input.

    Raises:
        ValueError: if ``modality`` is not "image" or "video" (previously
            this fell through and crashed with ``UnboundLocalError``).
    """
    placeholders = {"image": "<image>", "video": "<video>"}
    if modality not in placeholders:
        raise ValueError(
            f"Modality {modality} is not supported by LLaVA-OneVision.")
    placeholder = placeholders[modality]

    # Built from adjacent literals: a backslash continuation inside the
    # string used to splice source indentation between <|im_end|> and
    # <|im_start|>.
    prompts = [(f"<|im_start|>user {placeholder}\n{question}<|im_end|>\n"
                "<|im_start|>assistant\n") for question in questions]

    engine_args = EngineArgs(
        model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
        max_model_len=16384,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
401
402


403
# Mantis
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
    """Mantis-8B (SigLIP + Llama 3) using the Llama 3 chat header format."""
    assert modality == "image"

    user_header = "<|start_header_id|>user<|end_header_id|>\n\n"
    assistant_header = (
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
    # The image placeholder follows the question inside the user turn.
    prompts = [
        f"{user_header}{q}\n<image>{assistant_header}" for q in questions
    ]

    engine_args = EngineArgs(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=[128009],
    )
426
427
428


# MiniCPM-V
def run_minicpmv_base(questions: list[str], modality: str,
                      model_name: str) -> ModelRequestData:
    """Build a MiniCPM-V / MiniCPM-o request for image or video input.

    Args:
        questions: One prompt is produced per question.
        modality: "image" or "video".
        model_name: HF repo id of the checkpoint. The stop tokens below match
            the 2.6 / o2.6 generation; older versions need the commented
            alternatives.

    Returns:
        ModelRequestData with chat-template prompts and stop token ids.
    """
    assert modality in ["image", "video"]
    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

    # Known checkpoints (and supported modalities):
    # 2.0  (image): the official repo doesn't work yet, so use a fork for
    #      now: model_name = "HwwwH/MiniCPM-V-2"
    #      Details: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    # 2.5  (image):               model_name = "openbmb/MiniCPM-Llama3-V-2_5"
    # 2.6  (image, video):        model_name = "openbmb/MiniCPM-V-2_6"
    # o2.6 (image, video, audio): model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0: stop_token_ids = [tokenizer.eos_id]
    # 2.5: stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
    # 2.6 / o2.6:
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    # A tokenizer without an unk token can return None for a token it does
    # not know; a None stop id is meaningless, so drop it.
    stop_token_ids = [tid for tid in stop_token_ids if tid is not None]

    modality_placeholder = {
        "image": "(<image>./</image>)",
        "video": "(<video>./</video>)",
    }

    prompts = [
        tokenizer.apply_chat_template(
            [{
                'role': 'user',
                'content': f"{modality_placeholder[modality]}\n{question}"
            }],
            tokenize=False,
            add_generation_prompt=True) for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
491
492


493
def run_minicpmo(questions: list[str], modality: str) -> ModelRequestData:
    """MiniCPM-o 2.6, delegating to the shared MiniCPM builder."""
    checkpoint = "openbmb/MiniCPM-o-2_6"
    return run_minicpmv_base(questions, modality, checkpoint)
495
496


497
def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
    """MiniCPM-V 2.6, delegating to the shared MiniCPM builder."""
    checkpoint = "openbmb/MiniCPM-V-2_6"
    return run_minicpmv_base(questions, modality, checkpoint)
499
500


501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
# Mistral-3 HF-format
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
    """Mistral-Small-3.1 24B (HF format), sharded over 2 GPUs."""
    assert modality == "image"

    prompts = [f"<s>[INST]{q}\n[IMG][/INST]" for q in questions]

    # NOTE: Need L40 (or equivalent) to avoid OOM, even with
    # tensor_parallel_size=2.
    engine_args = EngineArgs(
        model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        max_model_len=8192,
        max_num_seqs=2,
        tensor_parallel_size=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)


524
# Llama 3.2
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
    """Llama-3.2-11B-Vision-Instruct with chat-template prompts."""
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a single L40 GPU.
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def as_conversation(question: str) -> list[dict]:
        # One user turn: the image part first, then the question text.
        return [{
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": question}],
        }]

    conversations = [as_conversation(q) for q in questions]
    prompts = tokenizer.apply_chat_template(conversations,
                                            add_generation_prompt=True,
                                            tokenize=False)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
561
562


563
# Molmo
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
    """Molmo-7B-D in bfloat16, one image per prompt."""
    assert modality == "image"

    model_name = "allenai/Molmo-7B-D-0924"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # Built from adjacent literals: a backslash continuation inside the
    # string used to splice source indentation between <|im_end|> and
    # <|im_start|>.
    prompts = [(f"<|im_start|>user <image>\n{question}<|im_end|>\n"
                "<|im_start|>assistant\n") for question in questions]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
585
586


587
# NVLM-D
def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
    """NVLM-D-72B, sharded over 4 GPUs, with chat-template prompts."""
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    conversations = []
    for q in questions:
        conversations.append([{'role': 'user', 'content': f"<image>\n{q}"}])
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
616
617


618
# PaliGemma
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
    """PaliGemma 3B mix-224.

    PaliGemma has special prompt format for VQA; this example ignores the
    questions and sends the "caption en" task prompt once per question.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="google/paligemma-3b-mix-224",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompts = ["caption en"] * len(questions)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
632
633


634
# PaliGemma 2
def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
    """PaliGemma 2 3B (DOCCI fine-tune, 448px).

    PaliGemma 2 has special prompt format for VQA; this example ignores the
    questions and sends the "caption en" task prompt once per question.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="google/paligemma2-3b-ft-docci-448",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompts = ["caption en"] * len(questions)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
648
649


650
# Phi-3-Vision
def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
    """Phi-3.5-vision-instruct with ``num_crops=16`` image preprocessing."""
    assert modality == "image"

    opening = "<|user|>\n<|image_1|>\n"
    closing = "<|end|>\n<|assistant|>\n"
    prompts = [f"{opening}{q}{closing}" for q in questions]

    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    engine_args = EngineArgs(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={"num_crops": 16},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
685
686


687
# Phi-4-multimodal-instruct
def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process image inputs.

    The vision weights live in the ``vision-lora`` folder of the downloaded
    snapshot and are returned as a LoRA request for ``main`` to register.
    """
    assert modality == "image"
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_path = os.path.join(model_path, "vision-lora")
    prompts = [
        f"<|user|><|image_1|>{question}<|end|><|assistant|>"
        for question in questions
    ]
    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        # Honor the CLI flag like every other example in this file; it was
        # previously dropped for this model only.
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
    )
716
717


718
# Pixtral HF-format
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
    """Pixtral-12B (HF community format), image after the question."""
    assert modality == "image"

    prompts = [f"<s>[INST]{q}\n[IMG][/INST]" for q in questions]

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine_args = EngineArgs(
        model="mistral-community/pixtral-12b",
        max_model_len=6144,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
738
739


740
# Qwen
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Qwen-VL (base): question text followed by a ``Picture 1`` tag."""
    assert modality == "image"

    prompts = [f"{q}Picture 1: <img></img>\n" for q in questions]

    engine_args = EngineArgs(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
759
760


761
# Qwen2-VL
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Qwen2-VL-7B-Instruct for image or video input.

    Raises:
        ValueError: if ``modality`` is not "image" or "video" (previously
            this crashed with ``UnboundLocalError`` on ``placeholder``).
    """
    placeholders = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported by Qwen2-VL.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
794
795


# Qwen2.5-VL
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Qwen2.5-VL-3B-Instruct for image or video input (video at 1 fps).

    Raises:
        ValueError: if ``modality`` is not "image" or "video" (previously
            this crashed with ``UnboundLocalError`` on ``placeholder``).
    """
    placeholders = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in placeholders:
        raise ValueError(
            f"Modality {modality} is not supported by Qwen2.5-VL.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
829
830


831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
    """Skywork-R1V-38B with chat-template prompts and stop tokens.

    Returns a ``ModelRequestData`` whose ``stop_token_ids`` contains only
    the ids the tokenizer actually resolved.
    """
    assert modality == "image"

    model_name = "Skywork/Skywork-R1V-38B"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [[{
        'role': 'user',
        'content': f"<image>\n{question}"
    }] for question in questions]
    prompts = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop tokens for SkyworkR1V
    # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
    stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    # A tokenizer without an unk token can return None for a token it does
    # not know; a None stop id is meaningless, so drop it.
    stop_token_ids = [tid for tid in stop_token_ids if tid is not None]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )


866
# Registry: model-type name -> builder returning its ModelRequestData.
# `main` looks the requested model type up here; entries keep their
# original (alphabetical-ish) insertion order.
model_example_map = {
    # A - G
    "aria": run_aria,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "deepseek_vl_v2": run_deepseek_vl2,
    "florence2": run_florence2,
    "fuyu": run_fuyu,
    "gemma3": run_gemma3,
    "glm4v": run_glm4v,
    # H - L
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    # M - N
    "mantis": run_mantis,
    "minicpmo": run_minicpmo,
    "minicpmv": run_minicpmv,
    "mistral3": run_mistral3,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    # P - S
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "phi4_mm": run_phi4mm,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
    "skywork_chat": run_skyworkr1v,
}


901
902
903
904
905
906
907
908
909
910
911
def get_multi_modal_input(args):
    """Load the demo media and the questions to ask about it.

    Returns:
        ``{"data": <image or video>, "questions": [str, ...]}`` where the
        image is an RGB ``ImageAsset("cherry_blossom")`` and the video is
        ``args.num_frames`` frames of ``sample_demo_1.mp4``.

    Raises:
        ValueError: if ``args.modality`` is neither "image" nor "video".
    """
    modality = args.modality

    if modality == "image":
        # Input image and question
        rgb_image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
        image_questions = [
            "What is the content of this image?",
            "Describe the content of this image in detail.",
            "What's in the image?",
            "Where is this image taken?",
        ]
        return {"data": rgb_image, "questions": image_questions}

    if modality == "video":
        # Input video and question
        frames = VideoAsset(name="sample_demo_1.mp4",
                            num_frames=args.num_frames).np_ndarrays
        return {"data": frames, "questions": ["Why is this video funny?"]}

    raise ValueError(f"Modality {modality} is not supported.")


def apply_image_repeat(image_repeat_prob, num_prompts, data,
                       prompts: list[str], modality):
    """Build ``num_prompts`` requests whose image repeats with a given probability.

    Used to simulate hit/miss for the MM preprocessor cache. For each request,
    with probability ``1 - image_repeat_prob`` the previous image is copied
    and its top-left pixel overwritten with a value derived from the request
    index (a cache "miss"); otherwise the previous image object is reused
    unchanged (a cache "hit"). The first "previous image" is ``data``.

    Args:
        image_repeat_prob: Probability in ``[0, 1]`` that a request reuses
            the previous image. ``None`` disables modification entirely, so
            every request shares ``data``.
        num_prompts: Number of requests to build.
        data: Initial image. Must support ``copy()`` and ``putpixel()``
            (e.g. a PIL image) whenever a miss can occur.
        prompts: Prompts assigned round-robin to the requests.
        modality: Key used inside ``"multi_modal_data"``.

    Returns:
        A list of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts, one per request.

    Raises:
        ValueError: If ``image_repeat_prob`` is not ``None`` and lies outside
            ``[0, 1]``.
    """
    # Raise instead of `assert` so validation survives `python -O`.
    if image_repeat_prob is not None and not 0.0 <= image_repeat_prob <= 1.0:
        raise ValueError(
            f"image_repeat_prob must be in [0, 1], got {image_repeat_prob}")

    no_yes = [0, 1]
    if image_repeat_prob is not None:
        probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        if image_repeat_prob is not None:
            res = random.choices(no_yes, probs)[0]
            if res == 0:
                # No repeat => Modify one pixel
                cur_image = cur_image.copy()
                # Encode `i` into one RGB pixel. Masking each channel to
                # 8 bits keeps values valid past 65535 prompts (identical to
                # the unmasked values below that).
                new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
                cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


def main(args):
    """Run offline multi-modal generation for ``args.model_type``.

    Builds the sample input for ``args.modality``, asks the selected
    ``run_*`` example in ``model_example_map`` for its engine args and
    prompts, constructs an ``LLM``, and prints the first completion of every
    request.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` parser).

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map`` or ``args.num_prompts`` is not positive.
    """
    import time

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")
    # Validate before loading the model so a bad flag fails fast; an
    # `assert` here would also be stripped under `python -O`.
    if args.num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be positive, got {args.num_prompts}.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    questions = mm_input["questions"]

    req_data = model_example_map[model](questions, modality)

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    # To maintain code compatibility in this script, we add LoRA here.
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)
    if req_data.lora_requests:
        for lora_request in req_data.lora_requests:
            llm.llm_engine.add_lora(lora_request=lora_request)

    # Don't want to check the flag multiple times, so just hijack `prompts`.
    if args.use_different_prompt_per_request:
        prompts = req_data.prompts
    else:
        prompts = [req_data.prompts[0]]

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=req_data.stop_token_ids)

    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompts[0],
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference, repeating images with probability
        # "image_repeat_prob".
        # NOTE(review): apply_image_repeat calls copy()/putpixel() on `data`;
        # confirm the video data supports these before combining
        # --modality video with --image-repeat-prob.
        inputs = apply_image_repeat(args.image_repeat_prob,
                                    args.num_prompts, data, prompts,
                                    modality)
    else:
        # Batch inference, using the same image for all prompts.
        inputs = [{
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: data
            },
        } for i in range(args.num_prompts)]

    # Time only the generate() call. perf_counter is monotonic, unlike
    # time.time(), so wall-clock adjustments cannot skew the measurement.
    start_time = time.perf_counter()
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    elapsed_time = time.perf_counter() - start_time
    if args.time_generate:
        print(f"-- generate time = {elapsed_time}")

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # CLI entry point. `args` is intentionally bound at module scope: some
    # run_* helpers (e.g. run_aria) read `args.disable_mm_preprocessor_cache`
    # as a global rather than taking it as a parameter.
    parser = FlexibleArgumentParser(
        description=("Demo on using vLLM for offline inference with "
                     "vision language models for text generation"))
    add_arg = parser.add_argument

    # Model / input selection.
    add_arg("--model-type", "-m",
            type=str,
            default="llava",
            choices=model_example_map.keys(),
            help="Huggingface \"model_type\".")
    add_arg("--num-prompts",
            type=int,
            default=4,
            help="Number of prompts to run.")
    add_arg("--modality",
            type=str,
            default="image",
            choices=["image", "video"],
            help="Modality of the input.")
    add_arg("--num-frames",
            type=int,
            default=16,
            help="Number of frames to extract from the video.")
    add_arg("--seed",
            type=int,
            default=None,
            help="Set the seed when initializing `vllm.LLM`.")

    # Cache / benchmarking knobs.
    add_arg("--image-repeat-prob",
            type=float,
            default=None,
            help=("Simulates the hit-ratio for multi-modal preprocessor cache"
                  " (if enabled)"))
    add_arg("--disable-mm-preprocessor-cache",
            action="store_true",
            help=("If True, disables caching of multi-modal "
                  "preprocessor/mapper."))
    add_arg("--time-generate",
            action="store_true",
            help="If True, then print the total generate() call time")
    add_arg("--use-different-prompt-per-request",
            action="store_true",
            help=("If True, then use different prompt (with the same "
                  "multi-modal data) for each request."))

    args = parser.parse_args()
    main(args)