# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
"""
9
10
import random

11
12
13
14
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
15
from vllm.assets.video import VideoAsset
16
17
from vllm.utils import FlexibleArgumentParser

18
19
20
21
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

22

23
# Aria
24
def run_aria(questions: list[str], modality: str):
    """Set up rhymes-ai/Aria and format one prompt per question.

    Only the "image" modality is accepted. Reads the module-level ``args``
    for the preprocessor-cache flag.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine = LLM(
        model="rhymes-ai/Aria",
        max_model_len=4096,
        max_num_seqs=2,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for question in questions:
        prompts.append(
            f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
            "<|im_end|>\n<|im_start|>assistant\n")

    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
    return engine, prompts, stop_token_ids
41
42
43


# BLIP-2
44
def run_blip2(questions: list[str], modality: str):
    """Set up Salesforce/blip2-opt-2.7b and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    prompts = []
    for question in questions:
        prompts.append(f"Question: {question} Answer:")

    engine = LLM(
        model="Salesforce/blip2-opt-2.7b",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
54
55
56


# Chameleon
57
def run_chameleon(questions: list[str], modality: str):
    """Set up facebook/chameleon-7b and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    prompts = []
    for question in questions:
        prompts.append(f"{question}<image>")

    engine = LLM(
        model="facebook/chameleon-7b",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
67
68


69
# Deepseek-VL2
70
def run_deepseek_vl2(questions: list[str], modality: str):
    """Set up deepseek-ai/deepseek-vl2-tiny and format one prompt per question.

    Only the "image" modality is accepted. The architecture name is
    overridden through ``hf_overrides``.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    engine = LLM(
        model="deepseek-ai/deepseek-vl2-tiny",
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    )

    prompts = []
    for question in questions:
        prompts.append(f"<|User|>: <image>\n{question}\n\n<|Assistant|>:")

    return engine, prompts, None
87
88


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Florence2
def run_florence2(question: list[str], modality: str):
    """Set up microsoft/Florence-2-large with a fixed captioning prompt.

    The question text is not used: every request gets the task prompt
    ``<MORE_DETAILED_CAPTION>``. Only the "image" modality is accepted.

    Args:
        question: The questions supplied by the caller. Only their count is
            used; a single ``str`` is treated as one question.
        modality: Must be "image".

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)`` where ``prompts`` is a list
        with one prompt per question, like the other ``run_*`` builders.
        ``main`` indexes ``prompts[0]``, so returning a bare string would
        send only its first character.
    """
    assert modality == "image"

    llm = LLM(model="microsoft/Florence-2-large",
              tokenizer="facebook/bart-large",
              max_num_seqs=8,
              trust_remote_code=True,
              dtype="bfloat16",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

    num_questions = 1 if isinstance(question, str) else len(question)
    prompts = ["<MORE_DETAILED_CAPTION>"] * num_questions
    stop_token_ids = None
    return llm, prompts, stop_token_ids


105
# Fuyu
106
def run_fuyu(questions: list[str], modality: str):
    """Set up adept/fuyu-8b and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    prompts = []
    for question in questions:
        prompts.append(f"{question}\n")

    engine = LLM(
        model="adept/fuyu-8b",
        max_model_len=2048,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
116
117
118


# GLM-4v
119
def run_glm4v(questions: list[str], modality: str):
    """Set up THUDM/glm-4v-9b and format one prompt per question.

    Only the "image" modality is accepted. The architecture name is
    overridden through ``hf_overrides``.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"

    engine = LLM(
        model="THUDM/glm-4v-9b",
        max_model_len=2048,
        max_num_seqs=2,
        trust_remote_code=True,
        enforce_eager=True,
        hf_overrides={"architectures": ["GLM4VForCausalLM"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # The eight spaces before the question are part of the prompt text: the
    # string was previously written with a backslash line continuation, which
    # keeps the next line's indentation inside the literal.
    prompts = []
    for question in questions:
        prompts.append(
            "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
            f"        {question}<|assistant|>")

    stop_token_ids = [151329, 151336, 151338]
    return engine, prompts, stop_token_ids
138
139
140


# H2OVL-Mississippi
141
def run_h2ovl(questions: list[str], modality: str):
    """Set up h2oai/h2ovl-mississippi-800m with chat-template prompts.

    Only the "image" modality is accepted. Prompts are rendered with the
    model tokenizer's chat template, one single-turn conversation per
    question.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; the stop list holds the
        tokenizer's EOS token id.
    """
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-800m"

    engine = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    conversations = []
    for question in questions:
        user_turn = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_turn])
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-800m
    return engine, prompts, [tokenizer.eos_token_id]
167
168
169


# Idefics3-8B-Llama3
170
def run_idefics3(questions: list[str], modality: str):
    """Set up HuggingFaceM4/Idefics3-8B-Llama3 and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    # if you are running out of memory, you can reduce the "longest_edge".
    # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_kwargs = {"size": {"longest_edge": 3 * 364}}

    engine = LLM(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for question in questions:
        prompts.append(
            f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
        )

    return engine, prompts, None
193
194
195


# InternVL
196
def run_internvl(questions: list[str], modality: str):
    """Set up OpenGVLab/InternVL2-2B with chat-template prompts.

    Only the "image" modality is accepted. Prompts are rendered with the
    model tokenizer's chat template, one single-turn conversation per
    question.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; the stop ids are the
        tokenizer ids of the model's stop words.
    """
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    engine = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    conversations = []
    for question in questions:
        user_turn = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_turn])
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_token_ids = []
    for token in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]:
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(token))

    return engine, prompts, stop_token_ids
225
226


227
# LLaVA-1.5
228
def run_llava(questions: list[str], modality: str):
    """Set up llava-hf/llava-1.5-7b-hf and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    prompts = []
    for question in questions:
        prompts.append(f"USER: <image>\n{question}\nASSISTANT:")

    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
240
241
242


# LLaVA-1.6/LLaVA-NeXT
243
def run_llava_next(questions: list[str], modality: str):
    """Set up llava-hf/llava-v1.6-mistral-7b-hf and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    prompts = []
    for question in questions:
        prompts.append(f"[INST] <image>\n{question} [/INST]")

    engine = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
252
253
254
255


# LlaVA-NeXT-Video
# Currently only support for video input
256
def run_llava_next_video(questions: list[str], modality: str):
    """Set up llava-hf/LLaVA-NeXT-Video-7B-hf and format one prompt per question.

    Only the "video" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "video"

    prompts = []
    for question in questions:
        prompts.append(f"USER: <video>\n{question} ASSISTANT:")

    engine = LLM(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
267
268


269
# LLaVA-OneVision
270
def run_llava_onevision(questions: list[str], modality: str):
    """Set up llava-hf/llava-onevision-qwen2-7b-ov-hf and format prompts.

    Both "image" and "video" are accepted; the prompt differs only in the
    ``<image>`` / ``<video>`` placeholder tag.

    Args:
        questions: Question strings, one prompt is produced for each.
        modality: "image" or "video".

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.

    Raises:
        ValueError: If ``modality`` is neither "image" nor "video". This is
            checked before the model is loaded; without the check the
            function would fail later with an unbound ``prompts`` variable.
    """
    if modality not in ("image", "video"):
        raise ValueError(f"Modality {modality} is not supported.")

    # The placeholder tag is spelled exactly like the modality name. The run
    # of spaces between the user turn and the assistant turn is kept as-is so
    # the prompt text does not change.
    prompts = [
        (f"<|im_start|>user <{modality}>\n{question}<|im_end|> "
         "        <|im_start|>assistant\n") for question in questions
    ]

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=16384,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompts, stop_token_ids
289
290


291
# Mantis
292
def run_mantis(questions: list[str], modality: str):
    """Set up TIGER-Lab/Mantis-8B-siglip-llama3 and format one prompt per question.

    Only the "image" modality is accepted. Each question, followed by an
    ``<image>`` tag, is substituted into a Llama-3 style chat template.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``.
    """
    assert modality == "image"

    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'  # noqa: E501
    prompts = []
    for question in questions:
        user_text = f"{question}\n<image>"
        prompts.append(llama3_template.format(user_text))

    engine = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, [128009]
309
310
311


# MiniCPM-V
312
def run_minicpmv_base(questions: list[str], modality: str, model_name):
    """Shared setup for the MiniCPM-V / MiniCPM-o examples.

    Accepts "image" and "video". Prompts are rendered with the model
    tokenizer's chat template, one single-turn conversation per question.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``.
    """
    assert modality in ["image", "video"]
    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

    # Known model names and the modalities they support:
    #   2.0  (image):               "HwwwH/MiniCPM-V-2"
    #        The official repo doesn't work yet, so we need to use a fork for now
    #        For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    #   2.5  (image):               "openbmb/MiniCPM-Llama3-V-2_5"
    #   2.6  (image, video):        "openbmb/MiniCPM-V-2_6"
    #   o2.6 (image, video, audio): "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    engine = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    #   2.0:        [tokenizer.eos_id]
    #   2.5:        [tokenizer.eos_id, tokenizer.eot_id]
    #   2.6 / o2.6: the two tokens used below
    stop_token_ids = []
    for token in ['<|im_end|>', '<|endoftext|>']:
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(token))

    modality_placeholder = {
        "image": "(<image>./</image>)",
        "video": "(<video>./</video>)",
    }

    prompts = []
    for question in questions:
        user_turn = {
            'role': 'user',
            'content': f"{modality_placeholder[modality]}\n{question}"
        }
        prompts.append(
            tokenizer.apply_chat_template([user_turn],
                                          tokenize=False,
                                          add_generation_prompt=True))
    return engine, prompts, stop_token_ids
369
370


371
372
def run_minicpmo(questions: list[str], modality: str):
    """Run the shared MiniCPM setup with the openbmb/MiniCPM-o-2_6 model."""
    model_name = "openbmb/MiniCPM-o-2_6"
    return run_minicpmv_base(questions, modality, model_name)
373
374


375
376
def run_minicpmv(questions: list[str], modality: str):
    """Run the shared MiniCPM setup with the openbmb/MiniCPM-V-2_6 model."""
    model_name = "openbmb/MiniCPM-V-2_6"
    return run_minicpmv_base(questions, modality, model_name)
377
378


379
# LLama 3.2
380
def run_mllama(questions: list[str], modality: str):
    """Set up meta-llama/Llama-3.2-11B-Vision-Instruct with chat-template prompts.

    Only the "image" modality is accepted. Each conversation is a single
    user turn holding an image part followed by the question text.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a single L40 GPU.
    engine = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    conversations = []
    for question in questions:
        image_part = {"type": "image"}
        text_part = {"type": "text", "text": f"{question}"}
        conversations.append([{
            "role": "user",
            "content": [image_part, text_part],
        }])

    prompts = tokenizer.apply_chat_template(conversations,
                                            add_generation_prompt=True,
                                            tokenize=False)
    return engine, prompts, None
413
414


415
# Molmo
416
def run_molmo(questions: list[str], modality: str):
    """Set up allenai/Molmo-7B-D-0924 and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    engine = LLM(
        model="allenai/Molmo-7B-D-0924",
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # The run of spaces between the two turns is part of the prompt text: the
    # string was previously written with a backslash line continuation, which
    # keeps the next line's indentation inside the literal.
    prompts = []
    for question in questions:
        prompts.append(f"<|im_start|>user <image>\n{question}<|im_end|> "
                       "        <|im_start|>assistant\n")

    return engine, prompts, None
434
435


436
# NVLM-D
437
def run_nvlm_d(questions: list[str], modality: str):
    """Set up nvidia/NVLM-D-72B with chat-template prompts.

    Only the "image" modality is accepted. The engine is created with
    ``tensor_parallel_size=4``.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    engine = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    conversations = []
    for question in questions:
        user_turn = {'role': 'user', 'content': f"<image>\n{question}"}
        conversations.append([user_turn])
    prompts = tokenizer.apply_chat_template(conversations,
                                            tokenize=False,
                                            add_generation_prompt=True)
    return engine, prompts, None
462
463


464
465
# PaliGemma
def run_paligemma(question: list[str], modality: str):
    """Set up google/paligemma-3b-mix-224 with a fixed captioning prompt.

    The question text is not used: every request gets the task prompt
    ``caption en``. Only the "image" modality is accepted.

    Args:
        question: The questions supplied by the caller. Only their count is
            used; a single ``str`` is treated as one question.
        modality: Must be "image".

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)`` where ``prompts`` holds one
        prompt per question, matching the other ``run_*`` builders. No stop
        token ids are set.
    """
    assert modality == "image"

    # PaliGemma is driven by a task prompt rather than a free-form question;
    # this example asks for an English caption.
    num_questions = 1 if isinstance(question, str) else len(question)
    prompts = ["caption en"] * num_questions
    llm = LLM(model="google/paligemma-3b-mix-224",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompts, stop_token_ids
474
475


476
477
# PaliGemma 2
def run_paligemma2(question: list[str], modality: str):
    """Set up google/paligemma2-3b-ft-docci-448 with a fixed captioning prompt.

    The question text is not used: every request gets the task prompt
    ``caption en``. Only the "image" modality is accepted.

    Args:
        question: The questions supplied by the caller. Only their count is
            used; a single ``str`` is treated as one question.
        modality: Must be "image".

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)`` where ``prompts`` holds one
        prompt per question, matching the other ``run_*`` builders. No stop
        token ids are set.
    """
    assert modality == "image"

    # PaliGemma 2 is driven by a task prompt rather than a free-form
    # question; this example asks for an English caption.
    num_questions = 1 if isinstance(question, str) else len(question)
    prompts = ["caption en"] * num_questions
    llm = LLM(model="google/paligemma2-3b-ft-docci-448",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompts, stop_token_ids


488
# Phi-3-Vision
489
def run_phi3v(questions: list[str], modality: str):
    """Set up microsoft/Phi-3.5-vision-instruct and format one prompt per question.

    Only the "image" modality is accepted. The image processor is created
    with ``num_crops=16``.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    prompts = []
    for question in questions:
        prompts.append(
            f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n")

    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    #
    # Note - mm_processor_kwargs can also be passed to generate/chat calls
    processor_kwargs = {"num_crops": 16}

    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompts, None
520
521


522
# Pixtral HF-format
523
def run_pixtral_hf(questions: list[str], modality: str):
    """Set up mistral-community/pixtral-12b and format one prompt per question.

    Only the "image" modality is accepted.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    engine = LLM(
        model="mistral-community/pixtral-12b",
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for question in questions:
        prompts.append(f"<s>[INST]{question}\n[IMG][/INST]")

    return engine, prompts, None
539
540


541
# Qwen
542
def run_qwen_vl(questions: list[str], modality: str):
    """Set up Qwen/Qwen-VL and format one prompt per question.

    Only the "image" modality is accepted. The architecture name is
    overridden through ``hf_overrides``.

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.
    """
    assert modality == "image"

    engine = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = []
    for question in questions:
        prompts.append(f"{question}Picture 1: <img></img>\n")

    return engine, prompts, None
557
558


559
# Qwen2-VL
560
def run_qwen2_vl(questions: list[str], modality: str):
    """Set up Qwen/Qwen2-VL-7B-Instruct and format one prompt per question.

    Both "image" and "video" are accepted; the prompt differs only in the
    vision placeholder token.

    Args:
        questions: Question strings, one prompt is produced for each.
        modality: "image" or "video".

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.

    Raises:
        ValueError: If ``modality`` is neither "image" nor "video". This is
            checked before the model is loaded; without the check the
            function would fail after loading with an unbound
            ``placeholder`` variable.
    """
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]
    stop_token_ids = None
    return llm, prompts, stop_token_ids
589
590


Roger Wang's avatar
Roger Wang committed
591
# Qwen2.5-VL
592
def run_qwen2_5_vl(questions: list[str], modality: str):
    """Set up Qwen/Qwen2.5-VL-3B-Instruct and format one prompt per question.

    Both "image" and "video" are accepted; the prompt differs only in the
    vision placeholder token.

    Args:
        questions: Question strings, one prompt is produced for each.
        modality: "image" or "video".

    Returns:
        Tuple ``(llm, prompts, stop_token_ids)``; no stop token ids are set.

    Raises:
        ValueError: If ``modality`` is neither "image" nor "video". This is
            checked before the model is loaded; without the check the
            function would fail after loading with an unbound
            ``placeholder`` variable.
    """
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [
        ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
         f"{question}<|im_end|>\n"
         "<|im_start|>assistant\n") for question in questions
    ]
    stop_token_ids = None
    return llm, prompts, stop_token_ids
Roger Wang's avatar
Roger Wang committed
621
622


623
# Maps each accepted --model-type value to the function that builds the
# engine, prompts and stop token ids for that model. Insertion order is the
# order in which the choices are listed by the argument parser.
model_example_map = dict([
    ("aria", run_aria),
    ("blip-2", run_blip2),
    ("chameleon", run_chameleon),
    ("deepseek_vl_v2", run_deepseek_vl2),
    ("florence2", run_florence2),
    ("fuyu", run_fuyu),
    ("glm4v", run_glm4v),
    ("h2ovl_chat", run_h2ovl),
    ("idefics3", run_idefics3),
    ("internvl_chat", run_internvl),
    ("llava", run_llava),
    ("llava-next", run_llava_next),
    ("llava-next-video", run_llava_next_video),
    ("llava-onevision", run_llava_onevision),
    ("mantis", run_mantis),
    ("minicpmo", run_minicpmo),
    ("minicpmv", run_minicpmv),
    ("mllama", run_mllama),
    ("molmo", run_molmo),
    ("NVLM_D", run_nvlm_d),
    ("paligemma", run_paligemma),
    ("paligemma2", run_paligemma2),
    ("phi3_v", run_phi3v),
    ("pixtral_hf", run_pixtral_hf),
    ("qwen_vl", run_qwen_vl),
    ("qwen2_vl", run_qwen2_vl),
    ("qwen2_5_vl", run_qwen2_5_vl),
])


654
655
656
657
658
659
660
661
662
663
664
def get_multi_modal_input(args):
    """Load the sample asset and the questions for ``args.modality``.

    Returns a dict of the form::

        {
            "data": image or video,
            "questions": list of question strings,
        }

    For "video", ``args.num_frames`` is passed to the video asset.

    Raises:
        ValueError: If ``args.modality`` is neither "image" nor "video".
    """
    modality = args.modality

    if modality == "image":
        # Input image and questions
        asset = ImageAsset("cherry_blossom")
        rgb_image = asset.pil_image.convert("RGB")
        return {
            "data": rgb_image,
            "questions": [
                "What is the content of this image?",
                "Describe the content of this image in detail.",
                "What's in the image?",
                "Where is this image taken?",
            ],
        }

    if modality == "video":
        # Input video and question
        asset = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames)
        return {
            "data": asset.np_ndarrays,
            "questions": ["Why is this video funny?"],
        }

    raise ValueError(f"Modality {args.modality} is not supported.")


692
693
def apply_image_repeat(image_repeat_prob, num_prompts, data,
                       prompts: list[str], modality):
    """Build request inputs whose image repeats with a given probability.

    Used to simulate hit/miss for the MM preprocessor cache. For each
    request a weighted coin is flipped: on "repeat" the previous image object
    is reused; otherwise the current image is copied and the pixel at (0, 0)
    is overwritten with a value derived from the request index, so the new
    image differs from the previous one.

    Args:
        image_repeat_prob: Probability in [0, 1] that a request reuses the
            previous image.
        num_prompts: Number of request inputs to build.
        data: Starting image. It must provide ``copy()`` and
            ``putpixel((x, y), value)``; it is never modified itself because
            changes are always made on a copy.
        prompts: Prompt strings, cycled over when ``num_prompts`` exceeds
            their count.
        modality: Key under which the image is stored in
            ``multi_modal_data``.

    Returns:
        List of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts, one per request.

    Raises:
        ValueError: If ``image_repeat_prob`` is ``None`` or outside [0, 1].
    """
    if image_repeat_prob is None or not 0 <= image_repeat_prob <= 1.0:
        raise ValueError(
            "image_repeat_prob must be a number between 0 and 1, "
            f"got {image_repeat_prob!r}.")

    outcomes = [False, True]
    weights = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        repeat = random.choices(outcomes, weights)[0]
        if not repeat:
            # No repeat => modify one pixel on a fresh copy. The index is
            # split into three bytes so every channel stays within 0-255.
            cur_image = cur_image.copy()
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompts[i % len(prompts)],
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


722
723
724
725
726
def main(args):
    """Run the selected example end to end and print each generated text.

    Looks up the builder for ``args.model_type``, loads the sample input for
    ``args.modality``, assembles one or more requests and calls
    ``llm.generate`` on them.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data, questions = mm_input["data"], mm_input["questions"]

    build_example = model_example_map[model]
    llm, prompts, stop_token_ids = build_example(questions, modality)
    # Don't want to check the flag multiple times, so just hijack `prompts`:
    # unless distinct prompts were requested, keep only the first one.
    if not args.use_different_prompt_per_request:
        prompts = [prompts[0]]

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2,
        max_tokens=64,
        stop_token_ids=stop_token_ids,
    )

    num_prompts = args.num_prompts
    assert num_prompts > 0
    if num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompts[0],
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference: repeat images with the specified probability of
        # "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob, num_prompts,
                                    data, prompts, modality)
    else:
        # Batch inference: use the same image for all prompts
        inputs = []
        for i in range(num_prompts):
            inputs.append({
                "prompt": prompts[i % len(prompts)],
                "multi_modal_data": {
                    modality: data
                },
            })

    timed = args.time_generate
    if timed:
        import time
        start_time = time.time()

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    if timed:
        elapsed_time = time.time() - start_time
        print("-- generate time = {}".format(elapsed_time))

    for request_output in outputs:
        print(request_output.outputs[0].text)


if __name__ == "__main__":
    # Command-line entry point. The parsed namespace is bound to the
    # module-level name `args`, which the run_* builders read directly for
    # the preprocessor-cache flag.
    parser = FlexibleArgumentParser(
        description=("Demo on using vLLM for offline inference with "
                     "vision language models for text generation"))

    # Which example to run and with what input.
    parser.add_argument("--model-type",
                        "-m",
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--num-prompts",
                        type=int,
                        default=4,
                        help="Number of prompts to run.")
    parser.add_argument("--modality",
                        type=str,
                        default="image",
                        choices=["image", "video"],
                        help="Modality of the input.")
    parser.add_argument("--num-frames",
                        type=int,
                        default=16,
                        help="Number of frames to extract from the video.")

    # Multi-modal preprocessor cache controls.
    parser.add_argument(
        "--image-repeat-prob",
        type=float,
        default=None,
        help=("Simulates the hit-ratio for multi-modal preprocessor cache"
              " (if enabled)"))
    parser.add_argument(
        "--disable-mm-preprocessor-cache",
        action="store_true",
        help="If True, disables caching of multi-modal preprocessor/mapper.")

    # Reporting and prompt selection.
    parser.add_argument(
        "--time-generate",
        action="store_true",
        help="If True, then print the total generate() call time")
    parser.add_argument(
        "--use-different-prompt-per-request",
        action="store_true",
        help=("If True, then use different prompt (with the same multi-modal "
              "data) for each request."))

    args = parser.parse_args()
    main(args)