# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
"""
import os
10
import random
11
12
from dataclasses import asdict
from typing import NamedTuple, Optional
13

14
from huggingface_hub import snapshot_download
15
16
from transformers import AutoTokenizer

17
from vllm import LLM, EngineArgs, SamplingParams
18
from vllm.assets.image import ImageAsset
19
from vllm.assets.video import VideoAsset
20
from vllm.lora.request import LoRARequest
21
22
from vllm.utils import FlexibleArgumentParser

23
24
25
26
27
28
29
30

class ModelRequestData(NamedTuple):
    """Request settings returned by each ``run_*`` example in this file.

    ``engine_args`` and ``prompts`` are always populated; the remaining
    fields default to ``None`` for examples that do not need them.
    """
    # Engine configuration (model name, length/batch limits, ...) built by
    # the example function.
    engine_args: EngineArgs
    # Fully formatted prompt strings, one per input question.
    prompts: list[str]
    # Stop-token ids chosen by the example; None when it sets none.
    stop_token_ids: Optional[list[int]] = None
    # LoRA adapters attached by the example; None when it uses none.
    lora_requests: Optional[list[LoRARequest]] = None


31
32
33
34
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

35

36
# Aria
37
def run_aria(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for the ``rhymes-ai/Aria`` model.

    Only the ``"image"`` modality is accepted. Returns the engine arguments,
    one formatted prompt per question, and the model's stop-token ids. The
    preprocessor-cache switch is read from the module-level ``args``.
    """
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    aria_engine_args = EngineArgs(
        model="rhymes-ai/Aria",
        max_model_len=4096,
        max_num_seqs=2,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # The question sits between the image placeholder and the end of the
    # user turn.
    user_turn = "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>"
    assistant_turn = "<|im_end|>\n<|im_start|>assistant\n"
    aria_prompts = [f"{user_turn}{q}{assistant_turn}" for q in questions]

    return ModelRequestData(
        engine_args=aria_engine_args,
        prompts=aria_prompts,
        stop_token_ids=[93532, 93653, 944, 93421, 1019, 93653, 93519],
    )
61
62


Jennifer Zhao's avatar
Jennifer Zhao committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Aya Vision
def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``CohereForAI/aya-vision-8b``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one formatted prompt per question.
    """
    assert modality == "image"

    aya_engine_args = EngineArgs(
        model="CohereForAI/aya-vision-8b",
        max_model_len=2048,
        max_num_seqs=2,
        mm_processor_kwargs={"crop_to_patches": True},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    turn_open = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>"
    turn_close = ("<|END_OF_TURN_TOKEN|>"
                  "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>")
    aya_prompts = [f"{turn_open}{q}{turn_close}" for q in questions]

    return ModelRequestData(engine_args=aya_engine_args, prompts=aya_prompts)


85
# BLIP-2
86
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``Salesforce/blip2-opt-6.7b``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one ``Question: ... Answer:`` prompt per question.
    """
    assert modality == "image"

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    blip2_prompts = ["Question: {} Answer:".format(q) for q in questions]

    blip2_engine_args = EngineArgs(
        model="Salesforce/blip2-opt-6.7b",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=blip2_engine_args,
                            prompts=blip2_prompts)
101
102
103


# Chameleon
104
def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``facebook/chameleon-7b``.

    Only the ``"image"`` modality is accepted. Each prompt is the question
    followed directly by the ``<image>`` placeholder.
    """
    assert modality == "image"

    image_tag = "<image>"
    chameleon_prompts = [f"{q}{image_tag}" for q in questions]

    chameleon_engine_args = EngineArgs(
        model="facebook/chameleon-7b",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=chameleon_engine_args,
                            prompts=chameleon_prompts)
119
120


121
# Deepseek-VL2
122
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``deepseek-ai/deepseek-vl2-tiny``.

    Only the ``"image"`` modality is accepted. The architecture is overridden
    to ``DeepseekVLV2ForCausalLM`` through ``hf_overrides``.
    """
    assert modality == "image"

    deepseek_engine_args = EngineArgs(
        model="deepseek-ai/deepseek-vl2-tiny",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    )

    user_prefix = "<|User|>: <image>\n"
    assistant_suffix = "\n\n<|Assistant|>:"
    deepseek_prompts = [
        f"{user_prefix}{q}{assistant_suffix}" for q in questions
    ]

    return ModelRequestData(engine_args=deepseek_engine_args,
                            prompts=deepseek_prompts)
144
145


146
# Florence2
147
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``microsoft/Florence-2-large``.

    Only the ``"image"`` modality is accepted. The question text is not used:
    every prompt is the fixed ``<MORE_DETAILED_CAPTION>`` task token, one per
    question.
    """
    assert modality == "image"

    florence_engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="facebook/bart-large",
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    task_token = "<MORE_DETAILED_CAPTION>"
    florence_prompts = [task_token for _unused in questions]

    return ModelRequestData(engine_args=florence_engine_args,
                            prompts=florence_prompts)
166
167


168
# Fuyu
169
def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``adept/fuyu-8b``.

    Only the ``"image"`` modality is accepted. Each prompt is the question
    followed by a newline.
    """
    assert modality == "image"

    fuyu_prompts = ["{}\n".format(q) for q in questions]

    fuyu_engine_args = EngineArgs(
        model="adept/fuyu-8b",
        max_model_len=2048,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=fuyu_engine_args,
                            prompts=fuyu_prompts)
184
185


186
# Gemma 3
187
def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``google/gemma-3-4b-it``.

    Only the ``"image"`` modality is accepted. Pan-and-scan is enabled
    through ``mm_processor_kwargs``.
    """
    assert modality == "image"

    gemma_engine_args = EngineArgs(
        model="google/gemma-3-4b-it",
        max_model_len=2048,
        max_num_seqs=2,
        mm_processor_kwargs={"do_pan_and_scan": True},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    turn_head = "<bos><start_of_turn>user\n<start_of_image>"
    turn_tail = "<end_of_turn>\n<start_of_turn>model\n"
    gemma_prompts = [f"{turn_head}{q}{turn_tail}" for q in questions]

    return ModelRequestData(engine_args=gemma_engine_args,
                            prompts=gemma_prompts)
207
208


209
# GLM-4v
210
def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``THUDM/glm-4v-9b``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments,
    one formatted prompt per question, and the model's stop-token ids. The
    preprocessor-cache switch is read from the module-level ``args``.
    """
    assert modality == "image"
    model_name = "THUDM/glm-4v-9b"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_seqs=2,
        trust_remote_code=True,
        enforce_eager=True,
        hf_overrides={"architectures": ["GLM4VForCausalLM"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # Use implicit string concatenation rather than a backslash continuation
    # inside the literal: the continuation pulled the next source line's
    # indentation into the prompt, putting stray spaces between the image
    # placeholder and the question.
    prompts = [("<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
                f"{question}<|assistant|>") for question in questions]

    stop_token_ids = [151329, 151336, 151338]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=stop_token_ids,
    )
236
237
238


# H2OVL-Mississippi
239
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``h2oai/h2ovl-mississippi-800m``.

    Only the ``"image"`` modality is accepted. Prompts are rendered with the
    model tokenizer's chat template; the tokenizer's EOS id is the only stop
    token.
    """
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-800m"

    h2ovl_engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # One single-turn conversation per question, image placeholder first.
    conversations = []
    for question in questions:
        user_message = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_message])

    rendered_prompts = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-800m
    return ModelRequestData(
        engine_args=h2ovl_engine_args,
        prompts=rendered_prompts,
        stop_token_ids=[tokenizer.eos_token_id],
    )
270
271
272


# Idefics3-8B-Llama3
273
def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``HuggingFaceM4/Idefics3-8B-Llama3``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one formatted prompt per question.
    """
    assert modality == "image"

    # if you are running out of memory, you can reduce the "longest_edge".
    # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    image_size = {"longest_edge": 3 * 364}

    idefics_engine_args = EngineArgs(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs={"size": image_size},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    user_open = "<|begin_of_text|>User:<image>"
    user_close = "<end_of_utterance>\nAssistant:"
    idefics_prompts = [f"{user_open}{q}{user_close}" for q in questions]

    return ModelRequestData(engine_args=idefics_engine_args,
                            prompts=idefics_prompts)
299
300
301


# InternVL
302
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``OpenGVLab/InternVL2-2B``.

    Only the ``"image"`` modality is accepted. Prompts are rendered with the
    model tokenizer's chat template, and the stop-token ids are looked up
    from the tokenizer.
    """
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    internvl_engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # One single-turn conversation per question, image placeholder first.
    conversations = []
    for question in questions:
        user_message = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_message])

    rendered_prompts = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_words = ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>")
    stop_ids = [tokenizer.convert_tokens_to_ids(word) for word in stop_words]

    return ModelRequestData(
        engine_args=internvl_engine_args,
        prompts=rendered_prompts,
        stop_token_ids=stop_ids,
    )
336
337


338
# LLaVA-1.5
339
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``llava-hf/llava-1.5-7b-hf``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one ``USER: ... ASSISTANT:`` prompt per question.
    """
    assert modality == "image"

    llava_prompts = [
        "USER: <image>\n{}\nASSISTANT:".format(q) for q in questions
    ]

    llava_engine_args = EngineArgs(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=llava_engine_args,
                            prompts=llava_prompts)
356
357
358


# LLaVA-1.6/LLaVA-NeXT
359
def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``llava-hf/llava-v1.6-mistral-7b-hf``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one ``[INST] ... [/INST]`` prompt per question.
    """
    assert modality == "image"

    next_prompts = [
        "[INST] <image>\n{} [/INST]".format(q) for q in questions
    ]

    next_engine_args = EngineArgs(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=next_engine_args,
                            prompts=next_prompts)
373
374
375
376


# LlaVA-NeXT-Video
# Currently only support for video input
377
378
def run_llava_next_video(questions: list[str],
                         modality: str) -> ModelRequestData:
    """Prepare a video request for ``llava-hf/LLaVA-NeXT-Video-7B-hf``.

    Only the ``"video"`` modality is accepted. Returns the engine arguments
    and one formatted prompt per question.
    """
    assert modality == "video"

    video_prompts = [
        "USER: <video>\n{} ASSISTANT:".format(q) for q in questions
    ]

    video_engine_args = EngineArgs(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=video_engine_args,
                            prompts=video_prompts)
395
396


397
# LLaVA-OneVision
398
399
def run_llava_onevision(questions: list[str],
                        modality: str) -> ModelRequestData:
    """Prepare a request for ``llava-hf/llava-onevision-qwen2-7b-ov-hf``.

    Both the ``"image"`` and ``"video"`` modalities are supported; the two
    prompts differ only in the placeholder tag.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``.

    Returns:
        The engine arguments and the formatted prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
            Previously this surfaced as an unbound-variable error.
    """
    placeholders = {"video": "<video>", "image": "<image>"}
    if modality not in placeholders:
        msg = f"Modality {modality} is not supported."
        raise ValueError(msg)
    placeholder = placeholders[modality]

    # NOTE(review): the run of spaces between "<|im_end|>" and
    # "<|im_start|>assistant" reproduces the original prompt exactly (it came
    # from a backslash continuation inside the literal). It looks accidental;
    # confirm against the model's chat template before changing it.
    prompts = [(f"<|im_start|>user {placeholder}\n{question}<|im_end|> "
                "        <|im_start|>assistant\n") for question in questions]

    engine_args = EngineArgs(
        model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
        max_model_len=16384,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
423
424


425
# Mantis
426
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``TIGER-Lab/Mantis-8B-siglip-llama3``.

    Only the ``"image"`` modality is accepted. Each question is followed by
    the ``<image>`` placeholder and wrapped in a Llama-3 style user turn.
    """
    assert modality == "image"

    turn_open = '<|start_header_id|>user<|end_header_id|>\n\n'
    turn_close = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
    mantis_prompts = [
        f"{turn_open}{question}\n<image>{turn_close}"
        for question in questions
    ]

    mantis_engine_args = EngineArgs(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(
        engine_args=mantis_engine_args,
        prompts=mantis_prompts,
        stop_token_ids=[128009],
    )
448
449
450


# MiniCPM-V
451
def run_minicpmv_base(questions: list[str], modality: str, model_name):
    """Shared request builder for the MiniCPM-V / MiniCPM-o examples.

    Accepts the ``"image"`` and ``"video"`` modalities. Prompts are rendered
    with the tokenizer's chat template for ``model_name``, and the stop-token
    ids are looked up from that tokenizer.
    """
    assert modality in ("image", "video")
    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

    # Checkpoints callers may pass as `model_name`, with the modalities each
    # one supports:
    #   2.0:  "HwwwH/MiniCPM-V-2"            (image)
    #         The official repo doesn't work yet, so a fork is needed for now.
    #         See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    #   2.5:  "openbmb/MiniCPM-Llama3-V-2_5" (image)
    #   2.6:  "openbmb/MiniCPM-V-2_6"        (image, video)
    #   o2.6: "openbmb/MiniCPM-o-2_6"        (image, video, audio)
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    minicpm_engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    #   2.0: [tokenizer.eos_id]
    #   2.5: [tokenizer.eos_id, tokenizer.eot_id]
    #   2.6 / o2.6: the two tokens resolved below
    stop_words = ('<|im_end|>', '<|endoftext|>')
    stop_ids = [tokenizer.convert_tokens_to_ids(word) for word in stop_words]

    modality_placeholder = {
        "image": "(<image>./</image>)",
        "video": "(<video>./</video>)",
    }

    # Each question is rendered on its own as a single-turn conversation.
    rendered_prompts = []
    for question in questions:
        user_message = {
            'role': 'user',
            'content': f"{modality_placeholder[modality]}\n{question}",
        }
        rendered_prompts.append(
            tokenizer.apply_chat_template([user_message],
                                          tokenize=False,
                                          add_generation_prompt=True))

    return ModelRequestData(
        engine_args=minicpm_engine_args,
        prompts=rendered_prompts,
        stop_token_ids=stop_ids,
    )
513
514


515
def run_minicpmo(questions: list[str], modality: str) -> ModelRequestData:
    """Build the request for ``openbmb/MiniCPM-o-2_6`` via the shared helper."""
    checkpoint = "openbmb/MiniCPM-o-2_6"
    return run_minicpmv_base(questions, modality, checkpoint)
517
518


519
def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
    """Build the request for ``openbmb/MiniCPM-V-2_6`` via the shared helper."""
    checkpoint = "openbmb/MiniCPM-V-2_6"
    return run_minicpmv_base(questions, modality, checkpoint)
521
522


523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
# Mistral-3 HF-format
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for Mistral-Small-3.1-24B-Instruct-2503.

    Only the ``"image"`` modality is accepted. The engine is configured with
    ``tensor_parallel_size=2``.
    """
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    mistral_engine_args = EngineArgs(
        model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        max_model_len=8192,
        max_num_seqs=2,
        tensor_parallel_size=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    mistral_prompts = [
        "<s>[INST]{}\n[IMG][/INST]".format(q) for q in questions
    ]

    return ModelRequestData(engine_args=mistral_engine_args,
                            prompts=mistral_prompts)


546
# LLama 3.2
547
def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for Llama-3.2-11B-Vision-Instruct.

    Only the ``"image"`` modality is accepted. Prompts are rendered with the
    model tokenizer's chat template from structured image + text content.
    """
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a single L40 GPU.
    mllama_engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # One single-turn conversation per question: an image part followed by
    # the question text.
    conversations = []
    for question in questions:
        content_parts = [
            {"type": "image"},
            {"type": "text", "text": question},
        ]
        conversations.append([{"role": "user", "content": content_parts}])

    rendered_prompts = tokenizer.apply_chat_template(
        conversations,
        add_generation_prompt=True,
        tokenize=False,
    )

    return ModelRequestData(engine_args=mllama_engine_args,
                            prompts=rendered_prompts)
583
584


585
# Molmo
586
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``allenai/Molmo-7B-D-0924``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one formatted prompt per question.
    """
    assert modality == "image"

    molmo_engine_args = EngineArgs(
        model="allenai/Molmo-7B-D-0924",
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE(review): the run of spaces before "<|im_start|>assistant"
    # reproduces the original prompt exactly (it came from a backslash
    # continuation inside the literal); confirm whether it is intended.
    user_open = "<|im_start|>user <image>\n"
    assistant_open = "<|im_end|> " + "        <|im_start|>assistant\n"
    molmo_prompts = [f"{user_open}{q}{assistant_open}" for q in questions]

    return ModelRequestData(engine_args=molmo_engine_args,
                            prompts=molmo_prompts)
607
608


609
# NVLM-D
610
def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``nvidia/NVLM-D-72B``.

    Only the ``"image"`` modality is accepted. Prompts are rendered with the
    model tokenizer's chat template.
    """
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    nvlm_engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # One single-turn conversation per question, image placeholder first.
    conversations = []
    for question in questions:
        user_message = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_message])

    rendered_prompts = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(engine_args=nvlm_engine_args,
                            prompts=rendered_prompts)
638
639


640
# PaliGemma
641
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``google/paligemma-3b-mix-224``.

    Only the ``"image"`` modality is accepted. The question text is not used:
    every prompt is the fixed ``"caption en"`` task string, one per question.
    """
    assert modality == "image"

    # PaliGemma has special prompt format for VQA
    task_prompt = "caption en"
    paligemma_prompts = [task_prompt for _unused in questions]

    paligemma_engine_args = EngineArgs(
        model="google/paligemma-3b-mix-224",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=paligemma_engine_args,
                            prompts=paligemma_prompts)
654
655


656
# PaliGemma 2
657
def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``google/paligemma2-3b-ft-docci-448``.

    Only the ``"image"`` modality is accepted. The question text is not used:
    every prompt is the fixed ``"caption en"`` task string, one per question.
    """
    assert modality == "image"

    # PaliGemma 2 has special prompt format for VQA
    task_prompt = "caption en"
    paligemma2_prompts = [task_prompt for _unused in questions]

    paligemma2_engine_args = EngineArgs(
        model="google/paligemma2-3b-ft-docci-448",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=paligemma2_engine_args,
                            prompts=paligemma2_prompts)
670
671


672
# Phi-3-Vision
673
def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``microsoft/Phi-3.5-vision-instruct``.

    Only the ``"image"`` modality is accepted. ``num_crops`` is overridden to
    16 through ``mm_processor_kwargs``.
    """
    assert modality == "image"

    user_open = "<|user|>\n<|image_1|>\n"
    user_close = "<|end|>\n<|assistant|>\n"
    phi3v_prompts = [f"{user_open}{q}{user_close}" for q in questions]

    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    processor_overrides = {"num_crops": 16}

    phi3v_engine_args = EngineArgs(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs=processor_overrides,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    return ModelRequestData(engine_args=phi3v_engine_args,
                            prompts=phi3v_prompts)
707
708


709
# Phi-4-multimodal-instruct
710
def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``microsoft/Phi-4-multimodal-instruct``.

    Phi-4-multimodal-instruct supports both image and audio inputs; this
    example covers image inputs only. The model snapshot is downloaded first
    so that the vision LoRA directory inside it can be attached to the
    request.
    """
    assert modality == "image"

    snapshot_dir = snapshot_download("microsoft/Phi-4-multimodal-instruct")

    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_dir = os.path.join(snapshot_dir, "vision-lora")

    phi4_prompts = [
        "<|user|><|image_1|>{}<|end|><|assistant|>".format(q)
        for q in questions
    ]

    # NOTE(review): unlike the other examples in this file, no
    # disable_mm_preprocessor_cache value is passed here - confirm whether
    # that is intentional.
    phi4_engine_args = EngineArgs(
        model=snapshot_dir,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
    )

    vision_lora = LoRARequest("vision", 1, vision_lora_dir)
    return ModelRequestData(
        engine_args=phi4_engine_args,
        prompts=phi4_prompts,
        lora_requests=[vision_lora],
    )
738
739


740
# Pixtral HF-format
741
def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``mistral-community/pixtral-12b``.

    Only the ``"image"`` modality is accepted. Returns the engine arguments
    and one ``[INST] ... [/INST]`` prompt per question.
    """
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    pixtral_engine_args = EngineArgs(
        model="mistral-community/pixtral-12b",
        max_model_len=6144,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    pixtral_prompts = [
        "<s>[INST]{}\n[IMG][/INST]".format(q) for q in questions
    ]

    return ModelRequestData(engine_args=pixtral_engine_args,
                            prompts=pixtral_prompts)
760
761


762
# Qwen
763
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``Qwen/Qwen-VL``.

    Only the ``"image"`` modality is accepted. The architecture is overridden
    to ``QwenVLForConditionalGeneration`` through ``hf_overrides``.
    """
    assert modality == "image"

    qwen_engine_args = EngineArgs(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    picture_suffix = "Picture 1: <img></img>\n"
    qwen_prompts = [f"{q}{picture_suffix}" for q in questions]

    return ModelRequestData(engine_args=qwen_engine_args,
                            prompts=qwen_prompts)
781
782


783
# Qwen2-VL
784
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare a request for ``Qwen/Qwen2-VL-7B-Instruct``.

    Both the ``"image"`` and ``"video"`` modalities are supported; they
    differ only in the vision placeholder token put into the prompt.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``.

    Returns:
        The engine arguments and the formatted prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
            Previously this surfaced as an unbound-variable error.
    """
    # Resolve the placeholder before building anything else so that an
    # unsupported modality fails immediately with a clear message.
    placeholders = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in placeholders:
        msg = f"Modality {modality} is not supported."
        raise ValueError(msg)
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
816
817


Roger Wang's avatar
Roger Wang committed
818
# Qwen2.5-VL
819
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare a request for ``Qwen/Qwen2.5-VL-3B-Instruct``.

    Both the ``"image"`` and ``"video"`` modalities are supported; they
    differ only in the vision placeholder token put into the prompt.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``.

    Returns:
        The engine arguments and the formatted prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
            Previously this surfaced as an unbound-variable error.
    """
    # Resolve the placeholder before building anything else so that an
    # unsupported modality fails immediately with a clear message.
    placeholders = {"image": "<|image_pad|>", "video": "<|video_pad|>"}
    if modality not in placeholders:
        msg = f"Modality {modality} is not supported."
        raise ValueError(msg)
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
Roger Wang's avatar
Roger Wang committed
851
852


853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
    """Prepare an image request for ``Skywork/Skywork-R1V-38B``.

    Only the ``"image"`` modality is accepted. Prompts are rendered with the
    model tokenizer's chat template, and the stop-token ids are looked up
    from the tokenizer.
    """
    assert modality == "image"

    model_name = "Skywork/Skywork-R1V-38B"

    skywork_engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # One single-turn conversation per question, image placeholder first.
    conversations = []
    for question in questions:
        user_message = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_message])

    rendered_prompts = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Stop tokens for SkyworkR1V
    # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
    stop_words = ("<|end▁of▁sentence|>", "<|endoftext|>")
    stop_ids = [tokenizer.convert_tokens_to_ids(word) for word in stop_words]

    return ModelRequestData(
        engine_args=skywork_engine_args,
        prompts=rendered_prompts,
        stop_token_ids=stop_ids,
    )


888
# Maps each supported model-type key to the function that builds its request.
model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "deepseek_vl_v2": run_deepseek_vl2,
    "florence2": run_florence2,
    "fuyu": run_fuyu,
    "gemma3": run_gemma3,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "mantis": run_mantis,
    "minicpmo": run_minicpmo,
    "minicpmv": run_minicpmv,
    "mistral3": run_mistral3,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "phi4_mm": run_phi4mm,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
    "skywork_chat": run_skyworkr1v,
}


924
925
926
927
928
929
930
931
932
933
934
def get_multi_modal_input(args):
    """Load the demo asset and the questions matching ``args.modality``.

    Returns a dict of the form::

        {
            "data": image or video,
            "questions": list of question strings,
        }

    For "video", ``args.num_frames`` is forwarded to the video asset.

    Raises:
        ValueError: If ``args.modality`` is neither "image" nor "video".
    """
    modality = args.modality

    if modality == "video":
        # Input video and question
        frames = VideoAsset(
            name="sample_demo_1.mp4",
            num_frames=args.num_frames,
        ).np_ndarrays
        return {
            "data": frames,
            "questions": ["Why is this video funny?"],
        }

    if modality != "image":
        raise ValueError(f"Modality {modality} is not supported.")

    # Input image and questions
    asset = ImageAsset("cherry_blossom")
    rgb_image = asset.pil_image.convert("RGB")
    return {
        "data": rgb_image,
        "questions": [
            "What is the content of this image?",
            "Describe the content of this image in detail.",
            "What's in the image?",
            "Where is this image taken?",
        ],
    }


962
963
def apply_image_repeat(image_repeat_prob, num_prompts, data,
                       prompts: list[str], modality):
    """Build ``num_prompts`` requests, repeating the image with probability
    ``image_repeat_prob``.

    Used to simulate hit/miss for the MM preprocessor cache: for every
    request a weighted coin is flipped. On "repeat" the image of the
    previous request is reused as-is; otherwise the current image is copied
    and its top-left pixel is overwritten with a value derived from the
    request index, so the new image differs from the earlier ones.

    Args:
        image_repeat_prob: Probability in [0, 1] that a request reuses the
            previous request's image unchanged.
        num_prompts: Number of request dicts to build.
        data: The initial image. It must provide ``copy()`` and
            ``putpixel()`` whenever a non-repeat is drawn (presumably a
            3-channel PIL image -- the pixel is written as a 3-tuple).
        prompts: Prompt strings; they are cycled through by request index.
        modality: Key under which the image is stored in
            ``"multi_modal_data"``.

    Returns:
        A list of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts, one per request.
    """
    assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
    no_yes = [0, 1]
    probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        res = random.choices(no_yes, probs)[0]
        if res == 0:
            # No repeat => Modify one pixel (on a copy, so images already
            # handed out to earlier requests stay untouched).
            cur_image = cur_image.copy()
            # Spread the request index over the three 8-bit channels.
            # Every channel is masked so it stays within 0..255; without
            # the mask the middle channel overflows once i >= 65536.
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


992
993
994
995
996
def main(args):
    """Run the selected vision-language example and print the generations.

    Looks up the builder for ``args.model_type``, loads the demo image or
    video for ``args.modality``, creates the ``LLM`` from the builder's
    engine arguments (with ``args.seed`` applied), assembles one or more
    requests and prints the text of the first output of each request.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map`` (or the modality is rejected by
            ``get_multi_modal_input``).
    """
    model_type = args.model_type
    if model_type not in model_example_map:
        raise ValueError(f"Model type {model_type} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    mm_data, questions = mm_input["data"], mm_input["questions"]

    request = model_example_map[model_type](questions, modality)

    # The seed from the command line is merged on top of the engine args.
    llm_kwargs = asdict(request.engine_args) | {"seed": args.seed}
    llm = LLM(**llm_kwargs)

    # To maintain code compatibility in this script, we add LoRA here.
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)
    for adapter in request.lora_requests or ():
        llm.llm_engine.add_lora(lora_request=adapter)

    # Don't want to check the flag multiple times, so just hijack `prompts`:
    # without the flag every request shares the first prompt.
    if args.use_different_prompt_per_request:
        prompts = request.prompts
    else:
        prompts = [request.prompts[0]]

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2,
        max_tokens=64,
        stop_token_ids=request.stop_token_ids,
    )

    num_prompts = args.num_prompts
    assert num_prompts > 0

    if num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompts[0],
            "multi_modal_data": {
                modality: mm_data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference, repeating images with the specified probability
        # of "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob, num_prompts,
                                    mm_data, prompts, modality)
    else:
        # Batch inference, using the same image for all prompts
        inputs = [{
            "prompt": prompts[idx % len(prompts)],
            "multi_modal_data": {
                modality: mm_data
            },
        } for idx in range(num_prompts)]

    timed = args.time_generate
    if timed:
        import time
        started = time.time()

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    if timed:
        elapsed_time = time.time() - started
        print("-- generate time = {}".format(elapsed_time))

    for result in outputs:
        print(result.outputs[0].text)


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run `main`.
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for text generation")

    # --- Which example to run and with what input -------------------------
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="llava",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=4,
        help="Number of prompts to run.",
    )
    parser.add_argument(
        "--modality",
        type=str,
        default="image",
        choices=["image", "video"],
        help="Modality of the input.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=16,
        help="Number of frames to extract from the video.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )

    # --- Multi-modal preprocessor cache experiments -----------------------
    parser.add_argument(
        "--image-repeat-prob",
        type=float,
        default=None,
        help="Simulates the hit-ratio for multi-modal preprocessor cache"
        " (if enabled)",
    )
    parser.add_argument(
        "--disable-mm-preprocessor-cache",
        action="store_true",
        help="If True, disables caching of multi-modal preprocessor/mapper.",
    )

    # --- Timing and prompt selection --------------------------------------
    parser.add_argument(
        "--time-generate",
        action="store_true",
        help="If True, then print the total generate() call time",
    )
    parser.add_argument(
        "--use-different-prompt-per-request",
        action="store_true",
        help="If True, then use different prompt (with the same multi-modal "
        "data) for each request.",
    )

    # NOTE: `args` must stay a module-level name: the `run_*` builders read
    # `args.disable_mm_preprocessor_cache` from it as a global.
    args = parser.parse_args()
    main(args)