"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for text generation.

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
8
9
import random

10
11
12
13
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
14
from vllm.assets.video import VideoAsset
15
16
from vllm.utils import FlexibleArgumentParser

17
18
19
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
20
21


22
23
# Aria
def run_aria(question: str, modality: str):
    """Set up rhymes-ai/Aria for single-image inference.

    Args:
        question: Text question inserted into the Aria chat prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"
    model_name = "rhymes-ai/Aria"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=2,
              dtype="bfloat16",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
              "<|im_end|>\n<|im_start|>assistant\n")

    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
    return llm, prompt, stop_token_ids
39

40
41
42
43
44
45
46
47
48

# BLIP-2
def run_blip2(question: str, modality: str):
    """Set up Salesforce/blip2-opt-2.7b for single-image inference.

    Args:
        question: Text question inserted into the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
52
53


54
55
# Chameleon
def run_chameleon(question: str, modality: str):
    """Set up facebook/chameleon-7b for single-image inference.

    Args:
        question: Text question placed before the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b",
              max_model_len=4096,
              max_num_seqs=2,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


67
68
69
70
# Deepseek-VL2
def run_deepseek_vl2(question: str, modality: str):
    """Set up deepseek-ai/deepseek-vl2-tiny for single-image inference.

    Args:
        question: Text question inserted into the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "deepseek-ai/deepseek-vl2-tiny"

    # The architecture name is overridden explicitly via `hf_overrides`.
    llm = LLM(model=model_name,
              max_model_len=4096,
              max_num_seqs=2,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
              hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]})

    prompt = f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


84
85
86
# Fuyu
def run_fuyu(question: str, modality: str):
    """Set up adept/fuyu-8b for single-image inference.

    Args:
        question: Text question; the prompt is the question plus a newline.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b",
              max_model_len=2048,
              max_num_seqs=2,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
95
96


97
98
99
100
# GLM-4v
def run_glm4v(question: str, modality: str):
    """Set up THUDM/glm-4v-9b for single-image inference.

    Args:
        question: Text question; used unchanged as the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"
    model_name = "THUDM/glm-4v-9b"

    llm = LLM(model=model_name,
              max_model_len=2048,
              max_num_seqs=2,
              trust_remote_code=True,
              enforce_eager=True,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompt = question
    stop_token_ids = [151329, 151336, 151338]
    return llm, prompt, stop_token_ids
111
112


113
114
115
116
117
118
119
120
121
122
# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
    """Set up h2oai/h2ovl-mississippi-2b for single-image inference.

    The prompt is rendered with the model tokenizer's chat template.

    Args:
        question: Text question placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; the only stop token is
        the tokenizer's EOS token id.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]
    return llm, prompt, stop_token_ids


139
140
# Idefics3-8B-Llama3
def run_idefics3(question: str, modality: str):
    """Set up HuggingFaceM4/Idefics3-8B-Llama3 for single-image inference.

    Args:
        question: Text question inserted into the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"
    model_name = "HuggingFaceM4/Idefics3-8B-Llama3"

    llm = LLM(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        # if you are running out of memory, you can reduce the "longest_edge".
        # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
        mm_processor_kwargs={
            "size": {
                "longest_edge": 3 * 364
            },
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    prompt = (
        f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids
163
164


165
166
# InternVL
def run_internvl(question: str, modality: str):
    """Set up OpenGVLab/InternVL2-2B for single-image inference.

    The prompt is rendered with the model tokenizer's chat template, and
    the stop token ids are looked up from a list of stop-token strings.

    Args:
        question: Text question placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids
192
193
194


# LLaVA-1.5
195
def run_llava(question: str, modality: str):
    """Set up llava-hf/llava-1.5-7b-hf for single-image inference.

    Args:
        question: Text question inserted into the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
              max_model_len=4096,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
205
206
207


# LLaVA-1.6/LLaVA-NeXT
208
def run_llava_next(question: str, modality: str):
    """Set up llava-hf/llava-v1.6-mistral-7b-hf for single-image inference.

    Args:
        question: Text question inserted into the ``[INST]`` prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
              max_model_len=8192,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only supports video input
221
def run_llava_next_video(question: str, modality: str):
    """Set up llava-hf/LLaVA-NeXT-Video-7B-hf for video inference.

    Args:
        question: Text question placed after the ``<video>`` placeholder.
        modality: Must be ``"video"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
              max_model_len=8192,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
230
231


232
# LLaVA-OneVision
233
def run_llava_onevision(question: str, modality: str):
    """Set up llava-hf/llava-onevision-qwen2-7b-ov-hf for image or video.

    Args:
        question: Text question inserted into the prompt.
        modality: Either ``"image"`` or ``"video"``; selects the
            ``<image>`` / ``<video>`` placeholder tag.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    # Fail fast: an unknown modality used to leave `prompt` unbound, which
    # only surfaced as an UnboundLocalError after the model had been loaded.
    if modality not in ("image", "video"):
        raise ValueError(f"Modality {modality} is not supported.")

    # The placeholder tag is spelled exactly like the modality name.
    # NOTE(review): the run of spaces before the assistant turn reproduces
    # the previous prompt byte-for-byte (it came from a backslash line
    # continuation inside the literal) - confirm whether it is intended.
    prompt = (f"<|im_start|>user <{modality}>\n{question}<|im_end|> "
              "        <|im_start|>assistant\n")

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=16384,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


250
251
# Mantis
def run_mantis(question: str, modality: str):
    """Set up TIGER-Lab/Mantis-8B-siglip-llama3 for single-image inference.

    Args:
        question: Text question placed before the ``<image>`` placeholder
            inside a Llama-3 style chat template.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'  # noqa: E501
    prompt = llama3_template.format(f"{question}\n<image>")

    llm = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    stop_token_ids = [128009]
    return llm, prompt, stop_token_ids
265
266
267


# MiniCPM-V
268
269
270
def run_minicpmv_base(question: str, modality: str, model_name):
    """Set up a MiniCPM-V / MiniCPM-o checkpoint for image or video input.

    Args:
        question: Text question placed after the modality placeholder.
        modality: Either ``"image"`` or ``"video"``.
        model_name: HuggingFace model id to load for both the tokenizer and
            the engine.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. The stop tokens used are
        the ones listed below for the 2.6 / o2.6 versions.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality in ["image", "video"]
    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

    # Known checkpoints and the modalities they support:
    #   2.0  "HwwwH/MiniCPM-V-2"            image
    #        (The official repo doesn't work yet, so a fork is needed for
    #        now. See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630) # noqa
    #   2.5  "openbmb/MiniCPM-Llama3-V-2_5" image
    #   2.6  "openbmb/MiniCPM-V-2_6"        image, video
    #   o2.6 "openbmb/MiniCPM-o-2_6"        image, video, audio
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6 / o2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    modality_placeholder = {
        "image": "(<image>./</image>)",
        "video": "(<video>./</video>)",
    }

    messages = [{
        'role': 'user',
        'content': f'{modality_placeholder[modality]}\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids
323
324


325
326
327
328
329
330
331
332
def run_minicpmo(question: str, modality: str):
    """Run the shared MiniCPM setup with the ``openbmb/MiniCPM-o-2_6`` model."""
    return run_minicpmv_base(question=question,
                             modality=modality,
                             model_name="openbmb/MiniCPM-o-2_6")


def run_minicpmv(question: str, modality: str):
    """Run the shared MiniCPM setup with the ``openbmb/MiniCPM-V-2_6`` model."""
    return run_minicpmv_base(question=question,
                             modality=modality,
                             model_name="openbmb/MiniCPM-V-2_6")


333
334
# LLama 3.2
def run_mllama(question: str, modality: str):
    """Set up meta-llama/Llama-3.2-11B-Vision-Instruct for one image.

    The prompt is rendered with the model tokenizer's chat template from a
    user message containing an image part followed by a text part.

    Args:
        question: Text question used as the text part of the message.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{
        "role":
        "user",
        "content": [{
            "type": "image"
        }, {
            "type": "text",
            "text": f"{question}"
        }]
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           add_generation_prompt=True,
                                           tokenize=False)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


369
370
# Molmo
def run_molmo(question: str, modality: str):
    """Set up allenai/Molmo-7B-D-0924 for single-image inference.

    Args:
        question: Text question; used unchanged as the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "allenai/Molmo-7B-D-0924"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = question
    stop_token_ids = None
    return llm, prompt, stop_token_ids
385
386


387
388
389
390
391
392
393
394
395
396
397
398
# NVLM-D
def run_nvlm_d(question: str, modality: str):
    """Set up nvidia/NVLM-D-72B for single-image inference.

    The prompt is rendered with the model tokenizer's chat template.

    Args:
        question: Text question placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
410

411

412
413
# PaliGemma
def run_paligemma(question: str, modality: str):
    """Set up google/paligemma-3b-mix-224 for single-image inference.

    Args:
        question: Accepted for interface parity with the other runners but
            not used; the prompt is the fixed string ``"caption en"``.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    # PaliGemma has special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
422
423


424
425
# PaliGemma 2
def run_paligemma2(question: str, modality: str):
    """Set up google/paligemma2-3b-ft-docci-448 for single-image inference.

    Args:
        question: Accepted for interface parity with the other runners but
            not used; the prompt is the fixed string ``"caption en"``.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    # PaliGemma 2 has special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma2-3b-ft-docci-448",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
434
435


436
437
# Phi-3-Vision
def run_phi3v(question: str, modality: str):
    """Set up microsoft/Phi-3.5-vision-instruct for single-image inference.

    Args:
        question: Text question placed after the ``<|image_1|>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"

    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={"num_crops": 16},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids


467
468
# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
    """Set up mistral-community/pixtral-12b for single-image inference.

    Args:
        question: Text question placed before the ``[IMG]`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    model_name = "mistral-community/pixtral-12b"

    # NOTE: Need L40 (or equivalent) to avoid OOM
    llm = LLM(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


486
487
# Qwen
def run_qwen_vl(question: str, modality: str):
    """Set up Qwen/Qwen-VL for single-image inference.

    Args:
        question: Text question placed before the ``Picture 1`` image tag.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids
501
502


503
504
# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
    """Set up Qwen/Qwen2-VL-7B-Instruct for image or video inference.

    Args:
        question: Text question placed after the vision placeholder.
        modality: Either ``"image"`` or ``"video"``; selects the
            ``<|image_pad|>`` / ``<|video_pad|>`` placeholder.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; ``stop_token_ids`` is
        ``None``.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    # Validate before building the engine: an unknown modality used to
    # leave `placeholder` unbound and fail only after the model was loaded.
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


zhuwenwen's avatar
zhuwenwen committed
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
# GLM-4v
def run_glm4v(question: str, modality: str):
    """Set up THUDM/glm-4v-9b for single-image inference.

    NOTE: this is a second definition of ``run_glm4v`` in this module. It
    is defined later, so it is the one bound to the name at import time.
    It now forwards ``disable_mm_preprocessor_cache`` like the earlier
    definition, so the CLI flag is no longer silently ignored for GLM-4v.

    Args:
        question: Text question; used unchanged as the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads ``disable_mm_preprocessor_cache`` from the module-level
        ``args`` namespace.
    """
    assert modality == "image"
    model_name = "THUDM/glm-4v-9b"

    llm = LLM(model=model_name,
              max_model_len=2048,
              max_num_seqs=2,
              trust_remote_code=True,
              enforce_eager=True,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    prompt = question
    stop_token_ids = [151329, 151336, 151338]
    return llm, prompt, stop_token_ids


548
# Maps the `--model-type` CLI value to the function that builds the engine,
# the prompt and the stop token ids for that model.
model_example_map = {
    "aria": run_aria,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "deepseek_vl_v2": run_deepseek_vl2,
    "fuyu": run_fuyu,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "mantis": run_mantis,
    "minicpmo": run_minicpmo,
    "minicpmv": run_minicpmv,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
}


577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
def get_multi_modal_input(args):
    """Load the bundled sample asset and question for ``args.modality``.

    Returns a dict of the form::

        {
            "data": image or video,
            "question": question,
        }

    Raises ``ValueError`` when the modality is neither "image" nor "video".
    """
    kind = args.modality

    # Reject unknown modalities up front so the happy paths stay flat.
    if kind not in ("image", "video"):
        raise ValueError(f"Modality {args.modality} is not supported.")

    if kind == "video":
        # Sample video, reduced to the requested number of frames.
        asset = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames)
        return {
            "data": asset.np_ndarrays,
            "question": "Why is this video funny?",
        }

    # Sample image, converted to RGB.
    picture = ImageAsset("cherry_blossom").pil_image
    return {
        "data": picture.convert("RGB"),
        "question": "What is the content of this image?",
    }


610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
    """Build ``num_prompts`` inputs, repeating the image with a probability.

    Used to simulate hit/miss for the MM preprocessor cache: for every
    prompt the previous image is reused with probability
    ``image_repeat_prob``; otherwise a copy is made and its top-left pixel
    is overwritten with a value derived from the prompt index.

    Args:
        image_repeat_prob: Probability in ``[0, 1]`` of reusing the
            previous image unchanged.
        num_prompts: Number of inputs to generate.
        data: Starting image. It is never mutated, but it must provide
            ``copy()`` and ``putpixel()`` whenever a "no repeat" is drawn.
        prompt: Prompt text shared by all inputs.
        modality: Key under which the image is stored in
            ``multi_modal_data``.

    Returns:
        A list of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts.
    """
    assert 0 <= image_repeat_prob <= 1.0
    no_yes = [0, 1]
    probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        if random.choices(no_yes, probs)[0] == 0:
            # No repeat => Modify one pixel
            cur_image = cur_image.copy()
            # Encode the prompt index as three 8-bit channels. Each channel
            # is masked to 0..255; previously the middle channel was
            # `i // 256`, which left the valid range once i reached 65536.
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompt,
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs


639
640
641
642
643
def main(args):
    """Run the selected vision-language example and print the generations.

    Args:
        args: Parsed CLI namespace. This function reads ``model_type``,
            ``modality``, ``num_prompts``, ``image_repeat_prob`` and
            ``time_generate`` from it; ``get_multi_modal_input`` reads
            further fields from the same object.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``.

    Note:
        The per-model ``run_*`` functions read the module-level ``args``
        rather than this parameter.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference:
        # repeat images with specified probability of "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob,
                                    args.num_prompts, data, prompt,
                                    modality)
    else:
        # Batch inference: use the same image for all prompts
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    if args.time_generate:
        import time

        # perf_counter() is monotonic, so the measured interval is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()
        outputs = llm.generate(inputs, sampling_params=sampling_params)
        elapsed_time = time.perf_counter() - start_time
        print("-- generate time = {}".format(elapsed_time))
    else:
        outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # Command-line entry point: build the parser, parse, and run the demo.
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for text generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')

    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help='Simulates the hit-ratio for multi-modal preprocessor cache'
        ' (if enabled)')

    parser.add_argument(
        '--disable-mm-preprocessor-cache',
        action='store_true',
        help='If True, disables caching of multi-modal preprocessor/mapper.')

    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time')

    # `args` is deliberately a module-level name: the run_* functions above
    # read `args.disable_mm_preprocessor_cache` from it.
    args = parser.parse_args()
    main(args)