# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow the corresponding
examples in the HuggingFace model repository.
"""
9
import os
10
import random
11
12
from dataclasses import asdict
from typing import NamedTuple, Optional
13

14
from huggingface_hub import snapshot_download
15
16
from transformers import AutoTokenizer

17
from vllm import LLM, EngineArgs, SamplingParams
18
from vllm.assets.image import ImageAsset
19
from vllm.assets.video import VideoAsset
20
from vllm.lora.request import LoRARequest
21
22
from vllm.utils import FlexibleArgumentParser

23
24
25
26
27
28
29
30

class ModelRequestData(NamedTuple):
    """Everything one example needs to run a model through ``LLM``.

    Each ``run_*`` builder in this file returns one of these; ``main`` turns
    ``engine_args`` into the ``LLM`` constructor arguments and generates
    from ``prompts``.
    """

    # Engine configuration (model id, limits, overrides) for ``LLM(...)``.
    engine_args: EngineArgs
    # Text prompts that already contain the model-specific media placeholder.
    prompts: list[str]
    # Extra token ids that end generation; ``None`` leaves the default.
    stop_token_ids: Optional[list[int]] = None
    # LoRA adapters registered with the engine before generation, if any.
    lora_requests: Optional[list[LoRARequest]] = None


31
32
33
34
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

35

36
# Aria
37
def run_aria(questions: list[str], modality: str) -> ModelRequestData:
    """Build the Aria request: engine args, chat prompts and stop tokens.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"
    model_name = "rhymes-ai/Aria"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
                "<|im_end|>\n<|im_start|>assistant\n")
               for question in questions]

    # Each stop id only needs to appear once (93653 was previously listed
    # twice); stopping behavior is unchanged.
    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93519]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
61
62
63


# BLIP-2
64
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``Salesforce/blip2-opt-2.7b``.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="Salesforce/blip2-opt-2.7b",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # The prompt format on the HuggingFace model repository is inaccurate;
    # the "Question: ... Answer:" form used here follows
    # https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 # noqa
    prompts = [f"Question: {q} Answer:" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
79
80
81


# Chameleon
82
def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Chameleon-7B; the image token follows the question.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="facebook/chameleon-7b",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"{q}<image>" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
97
98


99
# Deepseek-VL2
100
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``deepseek-ai/deepseek-vl2-tiny``.

    The architecture is overridden to ``DeepseekVLV2ForCausalLM``. Only the
    ``image`` modality is supported.
    """
    assert modality == "image"

    model_name = "deepseek-ai/deepseek-vl2-tiny"
    prompts = [f"<|User|>: <image>\n{q}\n\n<|Assistant|>:" for q in questions]

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
122
123


124
# Florence2
125
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Florence-2-large, loaded with the BART tokenizer.

    This example ignores the question text and sends the
    ``<MORE_DETAILED_CAPTION>`` task prompt once per question. Only the
    ``image`` modality is supported.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="facebook/bart-large",
        max_num_seqs=8,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = ["<MORE_DETAILED_CAPTION>"] * len(questions)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
143
144


145
# Fuyu
146
def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Fuyu-8B: each prompt is the question plus a newline.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="adept/fuyu-8b",
        max_model_len=2048,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"{q}\n" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
161
162


163
# Gemma 3
164
def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``google/gemma-3-4b-it`` with pan-and-scan enabled.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="google/gemma-3-4b-it",
        max_model_len=2048,
        max_num_seqs=2,
        mm_processor_kwargs={"do_pan_and_scan": True},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # One user turn (image marker, then the question) followed by an open
    # model turn for the answer.
    turn_template = ("<bos><start_of_turn>user\n"
                     "<start_of_image>{}<end_of_turn>\n"
                     "<start_of_turn>model\n")
    prompts = [turn_template.format(q) for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
184
185


186
# GLM-4v
187
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``THUDM/glm-4v-9b``; image modality only.

    Returns the engine args, one prompt per question and the stop token ids
    used for this model.
    """
    assert modality == "image"
    model_name = "THUDM/glm-4v-9b"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_seqs=2,
        trust_remote_code=True,
        enforce_eager=True,
        hf_overrides={"architectures": ["GLM4VForCausalLM"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # Use implicit string concatenation rather than a backslash continuation
    # inside the literal. A continuation would also copy the next source
    # line's indentation into the prompt, between the image and the question.
    prompts = [("<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
                f"{question}<|assistant|>") for question in questions]

    stop_token_ids = [151329, 151336, 151338]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
213
214
215


# H2OVL-Mississippi
216
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``h2oai/h2ovl-mississippi-800m``.

    Prompts are rendered with the model's chat template, and generation
    stops at the tokenizer's EOS token. Only the ``image`` modality is
    supported.
    """
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-800m"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    conversations = [
        [{"role": "user", "content": f"<image>\n{q}"}] for q in questions
    ]
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop tokens, see https://huggingface.co/h2oai/h2ovl-mississippi-800m
    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=[tokenizer.eos_token_id],
    )
247
248
249


# Idefics3-8B-Llama3
250
def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Idefics3-8B-Llama3; image modality only."""
    assert modality == "image"

    # If you run out of memory, lower "longest_edge" below; see
    # https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_kwargs = {"size": {"longest_edge": 3 * 364}}
    engine_args = EngineArgs(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [("<|begin_of_text|>User:<image>"
                f"{q}<end_of_utterance>\nAssistant:") for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
276
277
278


# InternVL
279
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for InternVL2-2B, prompted through its chat template.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    conversations = [
        [{"role": "user", "content": f"<image>\n{q}"}] for q in questions
    ]
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop words differ between InternVL variants; check the model card:
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_words = ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>")
    stop_token_ids = [tokenizer.convert_tokens_to_ids(w) for w in stop_words]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
313
314


315
# LLaVA-1.5
316
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for LLaVA-1.5-7B (HF format); image modality only."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"USER: <image>\n{q}\nASSISTANT:" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
333
334
335


# LLaVA-1.6/LLaVA-NeXT
336
def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for LLaVA-NeXT (v1.6, Mistral-7B); image only."""
    assert modality == "image"

    engine_args = EngineArgs(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"[INST] <image>\n{q} [/INST]" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
350
351
352
353


# LLaVA-NeXT-Video
# Currently only video input is supported
354
355
def run_llava_next_video(questions: list[str],
                         modality: str) -> ModelRequestData:
    """Request data for ``llava-hf/LLaVA-NeXT-Video-7B-hf``.

    Currently only the ``video`` modality is supported.
    """
    assert modality == "video"

    engine_args = EngineArgs(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"USER: <video>\n{q} ASSISTANT:" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
372
373


374
# LLaVA-OneVision
375
376
def run_llava_onevision(questions: list[str],
                        modality: str) -> ModelRequestData:
    """Request data for ``llava-hf/llava-onevision-qwen2-7b-ov-hf``.

    Supports the ``image`` and ``video`` modalities.

    Raises:
        ValueError: if ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Reject unsupported modalities explicitly instead of leaving ``prompts``
    # unbound, which used to fail later with an UnboundLocalError.
    placeholders = {"image": "<image>", "video": "<video>"}
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    # Use implicit string concatenation so no source indentation leaks into
    # the prompt. A backslash continuation inside the literal would embed
    # the next line's leading spaces before "<|im_start|>assistant".
    prompts = [(f"<|im_start|>user {placeholder}\n{question}<|im_end|>"
                "<|im_start|>assistant\n") for question in questions]

    engine_args = EngineArgs(
        model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
        max_model_len=16384,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
400
401


402
# Mantis
403
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``TIGER-Lab/Mantis-8B-siglip-llama3``.

    Each question is wrapped in a single Llama 3 style user turn, with the
    image placeholder after the question text. Only the ``image`` modality
    is supported.
    """
    assert modality == "image"

    llama3_turn = ("<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
                   "<|start_header_id|>assistant<|end_header_id|>\n\n")
    prompts = [llama3_turn.format(f"{q}\n<image>") for q in questions]

    engine_args = EngineArgs(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # 128009 is presumably the id of <|eot_id|> in the Llama 3 vocabulary;
    # confirm against the tokenizer if you swap checkpoints.
    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=[128009],
    )
425
426
427


# MiniCPM-V
428
def run_minicpmv_base(questions: list[str], modality: str,
                      model_name: str) -> ModelRequestData:
    """Build request data shared by the MiniCPM-V / MiniCPM-o examples.

    Args:
        questions: user questions; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``.
        model_name: HuggingFace repo id of the checkpoint to load.

    Modalities per version (checkpoint in parentheses):
        2.0: image (``HwwwH/MiniCPM-V-2``, a fork; see comment below)
        2.5: image (``openbmb/MiniCPM-Llama3-V-2_5``)
        2.6: image, video (``openbmb/MiniCPM-V-2_6``)
        o2.6: image, video, audio (``openbmb/MiniCPM-o-2_6``)

    To use ``MiniCPM-o-2_6`` with audio inputs, see ``audio_language.py``.
    """
    assert modality in ("image", "video")
    # MiniCPM-V 2.0: the official repo doesn't work yet, so a fork is needed
    # for now. For more details, see:
    # https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE: The stop_token_ids differ between MiniCPM-V versions:
    #   2.0:        stop_token_ids = [tokenizer.eos_id]
    #   2.5:        stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
    #   2.6 / o2.6: the ids of '<|im_end|>' and '<|endoftext|>' (used here)
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    modality_placeholder = {
        "image": "(<image>./</image>)",
        "video": "(<video>./</video>)",
    }

    prompts = [
        tokenizer.apply_chat_template(
            [{
                'role': 'user',
                'content': f"{modality_placeholder[modality]}\n{question}"
            }],
            tokenize=False,
            add_generation_prompt=True) for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
490
491


492
def run_minicpmo(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for MiniCPM-o 2.6 via the shared MiniCPM builder."""
    checkpoint = "openbmb/MiniCPM-o-2_6"
    return run_minicpmv_base(questions, modality, checkpoint)
494
495


496
def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for MiniCPM-V 2.6 via the shared MiniCPM builder."""
    checkpoint = "openbmb/MiniCPM-V-2_6"
    return run_minicpmv_base(questions, modality, checkpoint)
498
499


500
# Llama 3.2
501
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Llama-3.2-11B-Vision-Instruct; image modality only.

    The model's default max_num_seqs (256) and max_model_len (131072) may
    cause OOM. The smaller limits below have been confirmed to launch on a
    single L40 GPU; lower them further for lower-end GPUs.
    """
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    conversations = [[{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": q}],
    }] for q in questions]
    prompts = tokenizer.apply_chat_template(conversations,
                                            add_generation_prompt=True,
                                            tokenize=False)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
537
538


539
# Molmo
540
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``allenai/Molmo-7B-D-0924``; image modality only."""
    assert modality == "image"

    model_name = "allenai/Molmo-7B-D-0924"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # Use implicit string concatenation so no source indentation leaks into
    # the prompt. A backslash continuation inside the literal would embed
    # the next line's leading spaces between the user and assistant turns.
    prompts = [(f"<|im_start|>user <image>\n{question}<|im_end|>"
                "<|im_start|>assistant\n") for question in questions]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
561
562


563
# NVLM-D
564
def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for NVLM-D-72B, prompted through its chat template.

    Runs with ``tensor_parallel_size=4``; adjust this as necessary to fit
    your GPUs. Only the ``image`` modality is supported.
    """
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    conversations = [
        [{"role": "user", "content": f"<image>\n{q}"}] for q in questions
    ]
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
592
593


594
# PaliGemma
595
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``google/paligemma-3b-mix-224``; image only.

    PaliGemma has its own special prompt format rather than chat turns.
    This example ignores the question text and sends "caption en" once per
    question.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="google/paligemma-3b-mix-224",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompts = ["caption en"] * len(questions)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
608
609


610
# PaliGemma 2
611
def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for ``google/paligemma2-3b-ft-docci-448``; image only.

    PaliGemma 2 has its own special prompt format rather than chat turns.
    This example ignores the question text and sends "caption en" once per
    question.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="google/paligemma2-3b-ft-docci-448",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompts = ["caption en"] * len(questions)

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
624
625


626
# Phi-3-Vision
627
def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Phi-3.5-vision-instruct; image modality only."""
    assert modality == "image"

    prompts = [f"<|user|>\n<|image_1|>\n{q}<|end|>\n<|assistant|>\n"
               for q in questions]

    # ``num_crops`` is an override kwarg for the multimodal image processor.
    # For Phi-3.5-vision-instruct it is recommended to use 16 for
    # single-frame inputs and 4 for multi-frame inputs. A larger value
    # generally yields more tokens per image, because the preprocessing may
    # scale the image up further. References in the model docs, and the
    # formula for image tokens after preprocessing:
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    engine_args = EngineArgs(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # mm_processor_kwargs can also be passed to generate/chat calls.
        mm_processor_kwargs={"num_crops": 16},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
661
662


663
# Phi-4-multimodal-instruct
664
def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process image inputs.

    The checkpoint is fetched with ``snapshot_download`` so that the path
    of its bundled ``vision-lora`` adapter can be passed as a LoRA request.
    """
    assert modality == "image"
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_path = os.path.join(model_path, "vision-lora")
    prompts = [
        f"<|user|><|image_1|>{question}<|end|><|assistant|>"
        for question in questions
    ]
    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        # Honor the preprocessor-cache CLI flag like the other examples do;
        # previously it was silently ignored for this model.
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        lora_requests=[LoRARequest("vision", 1, vision_lora_path)],
    )
692
693


694
# Pixtral HF-format
695
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Pixtral-12B in HF format; image modality only.

    NOTE: needs an L40 (or equivalent) to avoid OOM.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="mistral-community/pixtral-12b",
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"<s>[INST]{q}\n[IMG][/INST]" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
714
715


716
# Qwen
717
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for the original ``Qwen/Qwen-VL``; image modality only.

    The image reference ("Picture 1: <img></img>") follows the question.
    """
    assert modality == "image"

    engine_args = EngineArgs(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompts = [f"{q}Picture 1: <img></img>\n" for q in questions]

    return ModelRequestData(engine_args=engine_args, prompts=prompts)
735
736


737
# Qwen2-VL
738
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Qwen2-VL-7B-Instruct; supports image and video.

    Raises:
        ValueError: if ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Reject unsupported modalities explicitly instead of leaving
    # ``placeholder`` unbound, which used to raise UnboundLocalError later.
    placeholders = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
770
771


# Qwen2.5-VL
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Qwen2.5-VL-3B-Instruct; supports image and video.

    Raises:
        ValueError: if ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Reject unsupported modalities explicitly instead of leaving
    # ``placeholder`` unbound, which used to raise UnboundLocalError later.
    placeholders = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
805
806


807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
    """Request data for Skywork-R1V-38B, prompted through its chat template.

    Only the ``image`` modality is supported.
    """
    assert modality == "image"

    model_name = "Skywork/Skywork-R1V-38B"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    conversations = [
        [{"role": "user", "content": f"<image>\n{q}"}] for q in questions
    ]
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop words, see
    # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
    stop_words = ("<|end▁of▁sentence|>", "<|endoftext|>")
    stop_token_ids = [tokenizer.convert_tokens_to_ids(w) for w in stop_words]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )


842
# Maps the model type chosen on the command line (``args.model_type`` in
# ``main``) to the builder that returns its ``ModelRequestData``.
# Insertion order follows the list below.
model_example_map = {
    model_type: builder
    for model_type, builder in (
        ("aria", run_aria),
        ("blip-2", run_blip2),
        ("chameleon", run_chameleon),
        ("deepseek_vl_v2", run_deepseek_vl2),
        ("florence2", run_florence2),
        ("fuyu", run_fuyu),
        ("gemma3", run_gemma3),
        ("glm4v", run_glm4v),
        ("h2ovl_chat", run_h2ovl),
        ("idefics3", run_idefics3),
        ("internvl_chat", run_internvl),
        ("llava", run_llava),
        ("llava-next", run_llava_next),
        ("llava-next-video", run_llava_next_video),
        ("llava-onevision", run_llava_onevision),
        ("mantis", run_mantis),
        ("minicpmo", run_minicpmo),
        ("minicpmv", run_minicpmv),
        ("mllama", run_mllama),
        ("molmo", run_molmo),
        ("NVLM_D", run_nvlm_d),
        ("paligemma", run_paligemma),
        ("paligemma2", run_paligemma2),
        ("phi3_v", run_phi3v),
        ("phi4_mm", run_phi4mm),
        ("pixtral_hf", run_pixtral_hf),
        ("qwen_vl", run_qwen_vl),
        ("qwen2_vl", run_qwen2_vl),
        ("qwen2_5_vl", run_qwen2_5_vl),
        ("skywork_chat", run_skyworkr1v),
    )
}


876
877
878
879
880
881
882
883
884
885
886
def get_multi_modal_input(args):
    """Return the sample media and the questions to ask about it.

    Args:
        args: parsed CLI namespace; reads ``modality`` and, for video,
            ``num_frames``.

    Returns:
        ``{"data": image or video, "questions": list of question strings}``.
        The image is the ``cherry_blossom`` asset converted to RGB; the
        video is the ``np_ndarrays`` of the ``sample_demo_1.mp4`` asset
        loaded with ``num_frames=args.num_frames``.

    Raises:
        ValueError: if ``args.modality`` is not ``"image"`` or ``"video"``.
    """
    if args.modality == "image":
        # Input image and question
        image = ImageAsset("cherry_blossom") \
            .pil_image.convert("RGB")
        img_questions = [
            "What is the content of this image?",
            "Describe the content of this image in detail.",
            "What's in the image?",
            "Where is this image taken?",
        ]

        return {
            "data": image,
            "questions": img_questions,
        }

    if args.modality == "video":
        # Input video and question
        video = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames).np_ndarrays
        vid_questions = ["Why is this video funny?"]

        return {
            "data": video,
            "questions": vid_questions,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


914
915
def apply_image_repeat(image_repeat_prob, num_prompts, data,
                       prompts: list[str], modality):
    """Build ``num_prompts`` inputs that reuse the image with probability
    ``image_repeat_prob``.

    Used to simulate hit/miss for the MM preprocessor cache. For each input
    a weighted coin is drawn. On "repeat" the current image object is reused
    as-is. Otherwise the current image is copied and pixel (0, 0) is
    overwritten with a value derived from the input index, so images
    produced at different indices differ from one another.

    Args:
        image_repeat_prob: probability in [0, 1] of reusing the current
            image. ``None`` disables modification, so every input gets
            ``data`` unchanged.
        num_prompts: number of inputs to produce.
        data: the starting image; must support ``copy()`` and
            ``putpixel()`` (e.g. an RGB PIL image).
        prompts: prompt texts, cycled through in order.
        modality: key used inside ``multi_modal_data`` (e.g. ``"image"``).

    Returns:
        A list of ``{"prompt": str, "multi_modal_data": {modality: image}}``
        dicts.
    """
    if image_repeat_prob is not None:
        # Validate only when a probability was given. Comparing ``None``
        # with floats raises TypeError, which made the ``None`` check in the
        # loop below unreachable.
        assert 0.0 <= image_repeat_prob <= 1.0
        no_yes = [0, 1]
        probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        if image_repeat_prob is not None:
            res = random.choices(no_yes, probs)[0]
            if res == 0:
                # No repeat => modify one pixel. ``i`` is written in base 256
                # with every channel wrapped into 0..255. Without the wrap,
                # any i >= 65536 gave an out-of-range channel value. Values
                # repeat every 2**24 inputs.
                cur_image = cur_image.copy()
                new_val = (i // 256 // 256 % 256, i // 256 % 256, i % 256)
                cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


944
945
946
947
948
def main(args):
    """Run offline multi-modal generation for one example model.

    Looks up the per-model factory in ``model_example_map``, builds an
    ``LLM`` from the returned engine args (plus ``args.seed``), registers any
    LoRA adapters the factory returns, and assembles the request inputs:

    * a single request when ``args.num_prompts == 1``;
    * otherwise ``args.num_prompts`` requests, built by
      ``apply_image_repeat`` when ``args.image_repeat_prob`` is set, or all
      sharing the same multi-modal data.

    It then calls ``llm.generate`` and prints the first completion of every
    output. When ``args.time_generate`` is set, the elapsed time of the
    ``generate`` call is printed as well.

    Args:
        args: Parsed command-line namespace from the ``__main__`` parser.

    Raises:
        ValueError: If ``args.model_type`` is not in ``model_example_map`` or
            ``args.num_prompts`` is not positive.
    """
    import time

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")
    # Validate before loading the model so a bad flag fails fast instead of
    # after the (expensive) engine construction.
    if args.num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be positive, got {args.num_prompts}.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    questions = mm_input["questions"]

    req_data = model_example_map[model](questions, modality)

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    # To maintain code compatibility in this script, we add LoRA here.
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)
    if req_data.lora_requests:
        for lora_request in req_data.lora_requests:
            llm.llm_engine.add_lora(lora_request=lora_request)

    # Don't want to check the flag multiple times, so just hijack `prompts`:
    # with a one-element list, `prompts[i % len(prompts)]` below always
    # selects the first prompt.
    prompts = (req_data.prompts if args.use_different_prompt_per_request
               else [req_data.prompts[0]])

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=req_data.stop_token_ids)

    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompts[0],
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference: repeat images with specified probability of
        # "image_repeat_prob".
        inputs = apply_image_repeat(args.image_repeat_prob,
                                    args.num_prompts, data, prompts,
                                    modality)
    else:
        # Batch inference: use the same multi-modal data for all prompts.
        inputs = [{
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: data
            },
        } for i in range(args.num_prompts)]

    # perf_counter is monotonic and high-resolution, so it is the right
    # clock for an elapsed interval (time.time() can jump with clock
    # adjustments).
    start_time = time.perf_counter()
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    elapsed_time = time.perf_counter() - start_time
    if args.time_generate:
        print(f"-- generate time = {elapsed_time}")

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # Command-line entry point: declare the flags, parse them, run `main`.
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help='Simulates the hit-ratio for multi-modal preprocessor cache'
        ' (if enabled)')

    parser.add_argument(
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.')

    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time')

    parser.add_argument(
        '--use-different-prompt-per-request',
        action='store_true',
        help='If True, then use different prompt (with the same multi-modal '
        'data) for each request.')

    # `args` must stay a module-level name: the per-model `run_*` factories
    # read it as a global (e.g. `args.disable_mm_preprocessor_cache`).
    args = parser.parse_args()
    main(args)