"vscode:/vscode.git/clone" did not exist on "d151fde8341d34592e1e5e14d2152d067421cf63"
offline_inference_vision_language_multi_image.py 4.99 KB
Newer Older
1
2
3
4
5
6
7
8
"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models, using the chat template defined
by the model.
"""
from argparse import Namespace
from typing import List

9
10
11
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
12
13
14
15
16
17
18
19
20
21
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

# Text question that main() forwards, together with IMAGE_URLS, to the
# selected run_* function.
QUESTION = "What is the content of each image?"
# Images the question refers to. Their count also sets the per-prompt image
# limit configured by the model loaders below.
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]


22
23
def load_phi3v(question: str, image_urls: List[str]):
    """Create the Phi-3.5-vision engine and a multi-image prompt for it.

    Args:
        question: Text question placed after the image placeholders.
        image_urls: URLs of the images; only their count is used here, to
            size ``limit_mm_per_prompt`` and to number the placeholders.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple. ``stop_token_ids`` is
        always ``None`` for this model.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    # One "<|image_N|>" placeholder per image, numbered from 1.
    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids
34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

def load_internvl(question: str, image_urls: List[str]):
    """Create the InternVL2 engine and a chat-templated multi-image prompt.

    Args:
        question: Text question placed after the image placeholders.
        image_urls: URLs of the images; only their count is used here, to
            size ``limit_mm_per_prompt`` and to number the placeholders.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple, where ``prompt`` is the
        string produced by the model's chat template and ``stop_token_ids``
        holds the ids of the stop tokens the tokenizer could resolve.
    """
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # One "Image-N: <image>" placeholder per image, numbered from 1.
    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    candidate_ids = (tokenizer.convert_tokens_to_ids(token)
                     for token in stop_tokens)
    # A token missing from this variant's vocabulary may resolve to None;
    # drop those so only real ids are returned as stop ids.
    stop_token_ids = [
        token_id for token_id in candidate_ids if token_id is not None
    ]
    return llm, prompt, stop_token_ids


# Maps each supported --model-type value to its loader function. Every
# loader takes (question, image_urls) and returns
# (llm, prompt, stop_token_ids).
model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
}


def run_generate(model: str, question: str, image_urls: List[str]):
    """Run ``LLM.generate`` with a pre-built prompt and print the answers.

    Args:
        model: Key into ``model_example_map`` selecting the loader.
        question: Text question about the images.
        image_urls: URLs of the images; each one is fetched and passed as
            multi-modal input.

    Raises:
        KeyError: If ``model`` is not a key of ``model_example_map``.
    """
    llm, prompt, stop_token_ids = model_example_map[model](question,
                                                           image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    # The prompt already contains the image placeholders; the images
    # themselves are passed alongside it.
    images = [fetch_image(url) for url in image_urls]
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": images
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


94
95
96
97
98
99
def run_chat(model: str, question: str, image_urls: List[str]):
    """Run ``LLM.chat`` with a structured user message and print the answers.

    The prompt returned by the loader is discarded: the question and the
    image URLs are sent as message content parts instead.

    Args:
        model: Key into ``model_example_map`` selecting the loader.
        question: Text question about the images.
        image_urls: URLs of the images, sent as ``image_url`` content parts.

    Raises:
        KeyError: If ``model`` is not a key of ``model_example_map``.
    """
    llm, _, stop_token_ids = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    # The text part comes first, followed by one image part per URL.
    text_part = {
        "type": "text",
        "text": question,
    }
    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        },
    } for image_url in image_urls]

    outputs = llm.chat([{
        "role": "user",
        "content": [text_part, *image_parts],
    }],
                       sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args: Namespace):
    """Dispatch to the inference method selected on the command line.

    Args:
        args: Parsed arguments providing ``model_type`` (a key of
            ``model_example_map``) and ``method`` ("generate" or "chat").

    Raises:
        ValueError: If ``args.method`` is neither "generate" nor "chat".
    """
    model = args.model_type
    method = args.method

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    # Command-line entry point: parse the options, then hand off to main().
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')
    # Only model types that have a loader in model_example_map are accepted.
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)