# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
9
10
import random

11
12
13
14
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
15
from vllm.assets.video import VideoAsset
16
17
from vllm.utils import FlexibleArgumentParser

18
19
20
21
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

22

23
# Aria
24
def run_aria(questions: list[str], modality: str):
    """Set up rhymes-ai/Aria and format one prompt per question.

    Only the "image" modality is accepted. Reads the module-level ``args``
    for the preprocessor-cache flag.

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple.
    """
    assert modality == "image"
    repo_id = "rhymes-ai/Aria"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine = LLM(
        model=repo_id,
        max_model_len=4096,
        max_num_seqs=2,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for question in questions:
        prompts.append(
            f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
            "<|im_end|>\n<|im_start|>assistant\n")

    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
    return engine, prompts, stop_token_ids
41
42
43


# BLIP-2
44
def run_blip2(questions: list[str], modality: str):
    """Set up Salesforce/blip2-opt-2.7b and format one prompt per question.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    template = "Question: {} Answer:"
    prompts = [template.format(question) for question in questions]

    engine = LLM(
        model="Salesforce/blip2-opt-2.7b",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
54
55
56


# Chameleon
57
def run_chameleon(questions: list[str], modality: str):
    """Set up facebook/chameleon-7b and format one prompt per question.

    Only the "image" modality is accepted. Each prompt is the question
    followed by the ``<image>`` placeholder. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    template = "{}<image>"
    prompts = [template.format(question) for question in questions]

    engine = LLM(
        model="facebook/chameleon-7b",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
67
68


69
# Deepseek-VL2
70
def run_deepseek_vl2(questions: list[str], modality: str):
    """Set up deepseek-ai/deepseek-vl2-tiny and format one prompt per question.

    Only the "image" modality is accepted. The architecture is overridden
    to ``DeepseekVLV2ForCausalLM`` via ``hf_overrides``. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    repo_id = "deepseek-ai/deepseek-vl2-tiny"
    engine = LLM(
        model=repo_id,
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    )

    template = "<|User|>: <image>\n{}\n\n<|Assistant|>:"
    prompts = [template.format(question) for question in questions]
    return engine, prompts, None
87
88


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Florence2
def run_florence2(question: "str | list[str]", modality: str):
    """Set up microsoft/Florence-2-large with its fixed captioning prompt.

    Florence-2 is driven by a task token rather than a free-form question,
    so the question text itself is ignored; only the number of questions
    determines how many prompts are returned.

    Args:
        question: The question(s) supplied by the caller. A bare string is
            treated as a single question.
        modality: Must be "image".

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple where ``prompts`` is a
        list of strings, matching the other ``run_*`` helpers.
    """
    assert modality == "image"

    llm = LLM(
        model="microsoft/Florence-2-large",
        tokenizer="facebook/bart-large",
        max_num_seqs=8,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # `main` indexes the result (`prompts[0]`, `prompts[i % len(prompts)]`).
    # Returning a bare string here made `prompts[0]` the single character
    # "<" instead of the task token, so always hand back a list.
    num_prompts = 1 if isinstance(question, str) else max(len(question), 1)
    prompts = ["<MORE_DETAILED_CAPTION>"] * num_prompts
    stop_token_ids = None
    return llm, prompts, stop_token_ids


105
# Fuyu
106
def run_fuyu(questions: list[str], modality: str):
    """Set up adept/fuyu-8b and format one prompt per question.

    Only the "image" modality is accepted. Each prompt is the question
    terminated by a newline. Returns ``(llm, prompts, stop_token_ids)``;
    no stop tokens are set.
    """
    assert modality == "image"

    template = "{}\n"
    prompts = [template.format(question) for question in questions]

    engine = LLM(
        model="adept/fuyu-8b",
        max_model_len=2048,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
116
117
118


# GLM-4v
119
def run_glm4v(questions: list[str], modality: str):
    """Set up THUDM/glm-4v-9b and format one prompt per question.

    Only the "image" modality is accepted. The architecture is overridden
    to ``GLM4VForCausalLM`` via ``hf_overrides``. Returns
    ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"
    repo_id = "THUDM/glm-4v-9b"

    engine = LLM(
        model=repo_id,
        max_model_len=2048,
        max_num_seqs=2,
        trust_remote_code=True,
        enforce_eager=True,
        hf_overrides={"architectures": ["GLM4VForCausalLM"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # The prompt carries eight spaces between the image tokens and the
    # question (they came from a backslash line continuation inside the
    # string literal). They are spelled out here so the prompt text does
    # not depend on source indentation.
    gap = " " * 8
    prompts = []
    for question in questions:
        prompts.append(
            "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
            f"{gap}{question}<|assistant|>")

    stop_token_ids = [151329, 151336, 151338]
    return engine, prompts, stop_token_ids
138
139
140


# H2OVL-Mississippi
141
def run_h2ovl(questions: list[str], modality: str):
    """Set up h2oai/h2ovl-mississippi-800m with chat-templated prompts.

    Only the "image" modality is accepted. Each question becomes its own
    single-turn conversation rendered through the tokenizer's chat
    template. Returns ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"

    repo_id = "h2oai/h2ovl-mississippi-800m"
    engine = LLM(
        model=repo_id,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    prompts = []
    for question in questions:
        conversation = [{"role": "user", "content": f"<image>\n{question}"}]
        prompts.append(
            tokenizer.apply_chat_template(conversation,
                                          tokenize=False,
                                          add_generation_prompt=True))

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-800m
    stop_token_ids = [tokenizer.eos_token_id]
    return engine, prompts, stop_token_ids
169
170
171


# Idefics3-8B-Llama3
172
def run_idefics3(questions: list[str], modality: str):
    """Set up HuggingFaceM4/Idefics3-8B-Llama3 and format the prompts.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"
    repo_id = "HuggingFaceM4/Idefics3-8B-Llama3"

    # if you are running out of memory, you can reduce the "longest_edge".
    # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_kwargs = {"size": {"longest_edge": 3 * 364}}

    engine = LLM(
        model=repo_id,
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    template = (
        "<|begin_of_text|>User:<image>{}<end_of_utterance>\nAssistant:")
    prompts = [template.format(question) for question in questions]
    return engine, prompts, None
195
196
197


# InternVL
198
def run_internvl(questions: list[str], modality: str):
    """Set up OpenGVLab/InternVL2-2B with chat-templated prompts.

    Only the "image" modality is accepted. Each question becomes its own
    single-turn conversation rendered through the tokenizer's chat
    template. Returns ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"

    repo_id = "OpenGVLab/InternVL2-2B"
    engine = LLM(
        model=repo_id,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    prompts = []
    for question in questions:
        conversation = [{"role": "user", "content": f"<image>\n{question}"}]
        prompts.append(
            tokenizer.apply_chat_template(conversation,
                                          tokenize=False,
                                          add_generation_prompt=True))

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_token_ids = []
    for token in ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"):
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(token))
    return engine, prompts, stop_token_ids
229
230


231
# LLaVA-1.5
232
def run_llava(questions: list[str], modality: str):
    """Set up llava-hf/llava-1.5-7b-hf and format one prompt per question.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    template = "USER: <image>\n{}\nASSISTANT:"
    prompts = [template.format(question) for question in questions]

    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
244
245
246


# LLaVA-1.6/LLaVA-NeXT
247
def run_llava_next(questions: list[str], modality: str):
    """Set up llava-hf/llava-v1.6-mistral-7b-hf and format the prompts.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    template = "[INST] <image>\n{} [/INST]"
    prompts = [template.format(question) for question in questions]

    engine = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
256
257
258
259


# LlaVA-NeXT-Video
# Currently only support for video input
260
def run_llava_next_video(questions: list[str], modality: str):
    """Set up llava-hf/LLaVA-NeXT-Video-7B-hf and format the prompts.

    Only the "video" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "video"

    template = "USER: <video>\n{} ASSISTANT:"
    prompts = [template.format(question) for question in questions]

    engine = LLM(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
271
272


273
# LLaVA-OneVision
274
def run_llava_onevision(questions: list[str], modality: str):
    """Set up llava-hf/llava-onevision-qwen2-7b-ov-hf and format the prompts.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Either "image" or "video"; selects the ``<image>`` or
            ``<video>`` placeholder used in the prompt.

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple; no stop tokens are set.

    Raises:
        AssertionError: If ``modality`` is neither "image" nor "video".
    """
    # Fail fast. Without this check an unexpected modality left `prompts`
    # unassigned and only surfaced as an UnboundLocalError at the return
    # statement, after the model had already been loaded.
    assert modality in ("image", "video")

    # The text after the question is kept exactly as before, including the
    # run of spaces between <|im_end|> and <|im_start|> that came from a
    # backslash continuation in the earlier string literal.
    tail = "<|im_end|> " + " " * 8 + "<|im_start|>assistant\n"
    prompts = [
        f"<|im_start|>user <{modality}>\n{question}{tail}"
        for question in questions
    ]

    llm = LLM(
        model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
        max_model_len=16384,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    stop_token_ids = None
    return llm, prompts, stop_token_ids
293
294


295
# Mantis
296
def run_mantis(questions: list[str], modality: str):
    """Set up TIGER-Lab/Mantis-8B-siglip-llama3 with Llama-3 style prompts.

    Only the "image" modality is accepted. The architecture is overridden
    to ``MantisForConditionalGeneration`` via ``hf_overrides``. Returns
    ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"

    # Llama-3 chat layout: user header, "<question>\n<image>", then the
    # assistant header that the model continues from.
    prompts = [
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}\n<image>"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        for question in questions
    ]

    engine = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, [128009]
313
314
315


# MiniCPM-V
316
def run_minicpmv_base(questions: list[str], modality: str, model_name):
    """Shared setup for the MiniCPM-V / MiniCPM-o examples.

    Accepts the "image" and "video" modalities. Each question becomes its
    own single-turn conversation rendered through the tokenizer's chat
    template. Returns ``(llm, prompts, stop_token_ids)``.
    """
    assert modality in ["image", "video"]
    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

    # Candidate values for `model_name` and the modalities each supports:
    #   2.0  "HwwwH/MiniCPM-V-2"            image
    #        (The official repo doesn't work yet, so a fork is needed for
    #        now. See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630) # noqa
    #   2.5  "openbmb/MiniCPM-Llama3-V-2_5" image
    #   2.6  "openbmb/MiniCPM-V-2_6"        image, video
    #   o2.6 "openbmb/MiniCPM-o-2_6"        image, video, audio
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    engine = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    #   2.0: [tokenizer.eos_id]
    #   2.5: [tokenizer.eos_id, tokenizer.eot_id]
    #   2.6 / o2.6: the two tokens below
    stop_token_ids = []
    for token in ('<|im_end|>', '<|endoftext|>'):
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(token))

    placeholders = {
        "image": "(<image>./</image>)",
        "video": "(<video>./</video>)",
    }
    placeholder = placeholders[modality]

    prompts = []
    for question in questions:
        conversation = [{
            "role": "user",
            "content": f"{placeholder}\n{question}",
        }]
        prompts.append(
            tokenizer.apply_chat_template(conversation,
                                          tokenize=False,
                                          add_generation_prompt=True))
    return engine, prompts, stop_token_ids
373
374


375
376
def run_minicpmo(questions: list[str], modality: str):
    """Run the shared MiniCPM setup against openbmb/MiniCPM-o-2_6."""
    return run_minicpmv_base(questions,
                             modality,
                             model_name="openbmb/MiniCPM-o-2_6")
377
378


379
380
def run_minicpmv(questions: list[str], modality: str):
    """Run the shared MiniCPM setup against openbmb/MiniCPM-V-2_6."""
    return run_minicpmv_base(questions,
                             modality,
                             model_name="openbmb/MiniCPM-V-2_6")
381
382


383
# LLama 3.2
384
def run_mllama(questions: list[str], modality: str):
    """Set up meta-llama/Llama-3.2-11B-Vision-Instruct with chat prompts.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Must be "image".

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple where ``prompts`` is a
        list with one rendered chat prompt per question; no stop tokens
        are set.
    """
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Render each question as its own single-turn conversation. Passing
    # every question in one flat message list made the template treat them
    # as consecutive turns of a single chat and return one string, which
    # `main` then indexed character by character (`prompts[0]`).
    prompts = [
        tokenizer.apply_chat_template(
            [{
                "role": "user",
                "content": [{
                    "type": "image"
                }, {
                    "type": "text",
                    "text": f"{question}"
                }],
            }],
            add_generation_prompt=True,
            tokenize=False) for question in questions
    ]
    stop_token_ids = None
    return llm, prompts, stop_token_ids
417
418


419
# Molmo
420
def run_molmo(questions: list[str], modality: str):
    """Set up allenai/Molmo-7B-D-0924 and format one prompt per question.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    repo_id = "allenai/Molmo-7B-D-0924"
    engine = LLM(
        model=repo_id,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # The prompt carries one space plus eight more between <|im_end|> and
    # <|im_start|> (the latter came from a backslash line continuation
    # inside the string literal). They are spelled out here so the prompt
    # text does not depend on source indentation.
    gap = " " + " " * 8
    prompts = []
    for question in questions:
        prompts.append(f"<|im_start|>user <image>\n{question}<|im_end|>"
                       f"{gap}<|im_start|>assistant\n")
    return engine, prompts, None
438
439


440
# NVLM-D
441
def run_nvlm_d(questions: list[str], modality: str):
    """Set up nvidia/NVLM-D-72B with chat-templated prompts.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Must be "image".

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple where ``prompts`` is a
        list with one rendered chat prompt per question; no stop tokens
        are set.
    """
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Render each question as its own single-turn conversation. Passing
    # every question in one flat message list made the template treat them
    # as consecutive turns of a single chat and return one string, which
    # `main` then indexed character by character (`prompts[0]`).
    prompts = [
        tokenizer.apply_chat_template(
            [{
                'role': 'user',
                'content': f"<image>\n{question}"
            }],
            tokenize=False,
            add_generation_prompt=True) for question in questions
    ]
    stop_token_ids = None
    return llm, prompts, stop_token_ids
466
467


468
469
# PaliGemma
def run_paligemma(question: str, modality: str):
    """Set up google/paligemma-3b-mix-224 with a fixed captioning prompt.

    Only the "image" modality is accepted. The ``question`` argument is
    not used: the single prompt is always "caption en". Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    # PaliGemma has special prompt format for VQA
    prompts = ["caption en"]
    engine = LLM(
        model="google/paligemma-3b-mix-224",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
478
479


480
481
# PaliGemma 2
def run_paligemma2(question: str, modality: str):
    """Set up google/paligemma2-3b-ft-docci-448 with a fixed captioning prompt.

    Only the "image" modality is accepted. The ``question`` argument is
    not used: the single prompt is always "caption en". Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    # PaliGemma 2 has special prompt format for VQA
    prompts = ["caption en"]
    engine = LLM(
        model="google/paligemma2-3b-ft-docci-448",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None


492
# Phi-3-Vision
493
def run_phi3v(questions: list[str], modality: str):
    """Set up microsoft/Phi-3.5-vision-instruct and format the prompts.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    template = "<|user|>\n<|image_1|>\n{}<|end|>\n<|assistant|>\n"
    prompts = [template.format(question) for question in questions]

    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    processor_kwargs = {"num_crops": 16}

    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
524
525


526
# Pixtral HF-format
527
def run_pixtral_hf(questions: list[str], modality: str):
    """Set up mistral-community/pixtral-12b and format the prompts.

    Only the "image" modality is accepted. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    repo_id = "mistral-community/pixtral-12b"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine = LLM(
        model=repo_id,
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    template = "<s>[INST]{}\n[IMG][/INST]"
    prompts = [template.format(question) for question in questions]
    return engine, prompts, None
543
544


545
# Qwen
546
def run_qwen_vl(questions: list[str], modality: str):
    """Set up Qwen/Qwen-VL and format one prompt per question.

    Only the "image" modality is accepted. The architecture is overridden
    to ``QwenVLForConditionalGeneration`` via ``hf_overrides``. Returns
    ``(llm, prompts, stop_token_ids)``; no stop tokens are set.
    """
    assert modality == "image"

    engine = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    template = "{}Picture 1: <img></img>\n"
    prompts = [template.format(question) for question in questions]
    return engine, prompts, None
561
562


563
# Qwen2-VL
564
def run_qwen2_vl(questions: list[str], modality: str):
    """Set up Qwen/Qwen2-VL-7B-Instruct and format one prompt per question.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Either "image" or "video"; selects the vision pad token
            placed between ``<|vision_start|>`` and ``<|vision_end|>``.

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple; no stop tokens are set.

    Raises:
        AssertionError: If ``modality`` is neither "image" nor "video".
    """
    # Resolve the placeholder before loading the model. Previously an
    # unexpected modality left `placeholder` unassigned and only failed
    # while building the prompts, after the expensive model load.
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    assert modality in placeholders
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]
    stop_token_ids = None
    return llm, prompts, stop_token_ids
593
594


# Qwen2.5-VL
596
def run_qwen2_5_vl(questions: list[str], modality: str):
    """Set up Qwen/Qwen2.5-VL-3B-Instruct and format one prompt per question.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Either "image" or "video"; selects the vision pad token
            placed between ``<|vision_start|>`` and ``<|vision_end|>``.

    Returns:
        A ``(llm, prompts, stop_token_ids)`` tuple; no stop tokens are set.

    Raises:
        AssertionError: If ``modality`` is neither "image" nor "video".
    """
    # Resolve the placeholder before loading the model. Previously an
    # unexpected modality left `placeholder` unassigned and only failed
    # while building the prompts, after the expensive model load.
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    assert modality in placeholders
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]
    stop_token_ids = None
    return llm, prompts, stop_token_ids
625
626


627
# Maps each `--model-type` choice to the helper that builds its engine,
# prompts and stop tokens. Every helper is called as
# `helper(questions, modality)` and returns `(llm, prompts, stop_token_ids)`.
model_example_map = {
    "aria": run_aria,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "deepseek_vl_v2": run_deepseek_vl2,
    "florence2": run_florence2,
    "fuyu": run_fuyu,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "mantis": run_mantis,
    "minicpmo": run_minicpmo,
    "minicpmv": run_minicpmv,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
}


658
659
660
661
662
663
664
665
666
667
668
def get_multi_modal_input(args):
    """Load the sample asset and questions for the requested modality.

    Args:
        args: Parsed command-line namespace. ``args.modality`` selects the
            asset; ``args.num_frames`` is read only for "video".

    Returns:
        A dict of the form::

            {
                "data": image or video,
                "questions": list of question strings,
            }

    Raises:
        ValueError: If ``args.modality`` is neither "image" nor "video".
    """
    if args.modality == "image":
        # Input image and question
        image = ImageAsset("cherry_blossom") \
            .pil_image.convert("RGB")
        img_questions = [
            "What is the content of this image?",
            "Describe the content of this image in detail.",
            "What's in the image?",
            "Where is this image taken?",
        ]

        return {
            "data": image,
            "questions": img_questions,
        }

    if args.modality == "video":
        # Input video and question
        video = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames).np_ndarrays
        vid_questions = ["Why is this video funny?"]

        return {
            "data": video,
            "questions": vid_questions,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


696
697
def apply_image_repeat(image_repeat_prob, num_prompts, data,
                       prompts: list[str], modality):
    """Repeats images with provided probability of "image_repeat_prob".

    Used to simulate hit/miss for the MM preprocessor cache.

    Args:
        image_repeat_prob: Probability in [0, 1] that a request reuses the
            previous request's image unchanged.
        num_prompts: Number of request dicts to build.
        data: Starting image. It must provide ``copy()`` and
            ``putpixel(xy, value)`` (presumably a PIL image; arrays such
            as video frames do not offer ``putpixel``).
        prompts: Prompt strings, cycled through by request index.
        modality: Key under which the image is stored in
            ``multi_modal_data``.

    Returns:
        A list of ``num_prompts`` dicts, each with a ``"prompt"`` and a
        ``"multi_modal_data"`` entry.

    Raises:
        ValueError: If ``image_repeat_prob`` is ``None`` or outside [0, 1].
    """
    if image_repeat_prob is None or not 0 <= image_repeat_prob <= 1.0:
        raise ValueError(
            f"image_repeat_prob must be within [0, 1], got "
            f"{image_repeat_prob!r}")

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        if random.random() >= image_repeat_prob:
            # No repeat => Modify one pixel on a copy so this request's
            # image differs from the previous one. The request index is
            # split into three bytes; masking keeps every channel within
            # 0..255 (the middle channel used to overflow at i >= 65536).
            cur_image = cur_image.copy()
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


726
727
728
729
730
def main(args):
    """Run the selected example end to end and print each generated text.

    Looks up the helper for ``args.model_type``, loads the sample input
    for ``args.modality``, builds ``args.num_prompts`` requests and calls
    ``llm.generate`` on them.

    Raises:
        ValueError: If ``args.model_type`` has no registered helper.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data, questions = mm_input["data"], mm_input["questions"]

    build_example = model_example_map[model]
    llm, prompts, stop_token_ids = build_example(questions, modality)

    # Don't want to check the flag multiple times, so just hijack `prompts`.
    if not args.use_different_prompt_per_request:
        prompts = [prompts[0]]

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2,
        max_tokens=64,
        stop_token_ids=stop_token_ids,
    )

    num_prompts = args.num_prompts
    assert num_prompts > 0
    if num_prompts == 1:
        # Single inference
        inputs = {"prompt": prompts[0], "multi_modal_data": {modality: data}}
    elif args.image_repeat_prob is not None:
        # Batch inference: repeat images with the specified probability
        # of "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob, num_prompts,
                                    data, prompts, modality)
    else:
        # Batch inference: use the same image for all prompts
        inputs = []
        for i in range(num_prompts):
            inputs.append({
                "prompt": prompts[i % len(prompts)],
                "multi_modal_data": {
                    modality: data
                },
            })

    timed = args.time_generate
    if timed:
        import time
        start_time = time.time()

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    if timed:
        elapsed_time = time.time() - start_time
        print("-- generate time = {}".format(elapsed_time))

    for request_output in outputs:
        print(request_output.outputs[0].text)


if __name__ == "__main__":
    # Command-line entry point. The parsed namespace is bound to the
    # module-level name `args`, which the `run_*` helpers read for the
    # `disable_mm_preprocessor_cache` flag.
    parser = FlexibleArgumentParser(
        description=("Demo on using vLLM for offline inference with "
                     "vision language models for text generation"))

    # Which model and input to run.
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="llava",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=4,
        help="Number of prompts to run.",
    )
    parser.add_argument(
        "--modality",
        type=str,
        default="image",
        choices=["image", "video"],
        help="Modality of the input.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=16,
        help="Number of frames to extract from the video.",
    )

    # Multi-modal preprocessor cache controls.
    parser.add_argument(
        "--image-repeat-prob",
        type=float,
        default=None,
        help=("Simulates the hit-ratio for multi-modal preprocessor cache"
              " (if enabled)"),
    )
    parser.add_argument(
        "--disable-mm-preprocessor-cache",
        action="store_true",
        help="If True, disables caching of multi-modal preprocessor/mapper.",
    )

    # Reporting and prompt selection.
    parser.add_argument(
        "--time-generate",
        action="store_true",
        help="If True, then print the total generate() call time",
    )
    parser.add_argument(
        "--use-different-prompt-per-request",
        action="store_true",
        help=("If True, then use different prompt (with the same multi-modal "
              "data) for each request."),
    )

    args = parser.parse_args()
    main(args)