# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
7

8
import time
9
10
11
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
12

13
from vllm import LLM, EngineArgs, PromptType, SamplingParams
14
15
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
汪志鹏's avatar
汪志鹏 committed
16
from vllm.multimodal.utils import fetch_image
17
18
19
from vllm.utils import FlexibleArgumentParser


20
21
22
23
24
class ModelRequestData(NamedTuple):
    """Engine configuration paired with the prompts to run for one example model."""

    # Engine arguments for the model; `main` expands these into `LLM(...)`.
    engine_args: EngineArgs
    # Prompts that `main` hands to `llm.generate`.
    prompts: Sequence[PromptType]


汪志鹏's avatar
汪志鹏 committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def run_donut():
    """Build the engine arguments and the single DocVQA prompt for Donut.

    Returns:
        A ``ModelRequestData`` with the ``EngineArgs`` for
        ``naver-clova-ix/donut-base-finetuned-docvqa`` and one explicit
        encoder/decoder prompt whose encoder side carries the document image.
    """
    donut_engine_args = EngineArgs(
        model="naver-clova-ix/donut-base-finetuned-docvqa",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="float16",
        hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
    )

    # Why the encoder prompt is 4799 characters long:
    #   * donut-base-finetuned-docvqa takes 2560 x 1920 inputs with 4 x 4
    #     patches, i.e. 1920 / 4 = 480 patches high and 2560 / 4 = 640 wide.
    #   * The Swin encoder ("depths": [2, 2, 14, 2]) applies a "Patch Merging"
    #     step before stages 2, 3 and 4, each halving height and width:
    #       480 x 640 -> 240 x 320 -> 120 x 160 -> 60 x 80.
    #   * vLLM fills the image features into the encoder_prompt, and the
    #     tokenizer adds a `<pad>` token to that prompt, so the placeholder
    #     text must be 60 x 80 - 1 = 4799 characters.
    placeholder_len = 60 * 80 - 1

    document_url = "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"  # noqa: E501

    docvqa_request = {
        "encoder_prompt": {
            "prompt": "$" * placeholder_len,
            "multi_modal_data": {"image": fetch_image(document_url)},
        },
        "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>",  # noqa: E501
    }

    return ModelRequestData(
        engine_args=donut_engine_args,
        prompts=[docvqa_request],
    )


69
def run_florence2():
    """Build the engine arguments and example prompts for Florence-2.

    Returns:
        A ``ModelRequestData`` with the ``EngineArgs`` for
        ``microsoft/Florence-2-large`` and two image prompts: an implicit
        prompt that uses a task token and an explicit encoder/decoder prompt.
    """
    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="Isotr0py/Florence-2-tokenizer",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # implicit prompt with task token
            "prompt": "<DETAILED_CAPTION>",
            "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
        },
        {  # explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "Describe in detail what is shown in the image.",
                "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
            },
            "decoder_prompt": "",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
97
98
99


def run_mllama():
    """Build the engine arguments and example prompts for Llama 3.2 Vision.

    Returns:
        A ``ModelRequestData`` with the ``EngineArgs`` for
        ``meta-llama/Llama-3.2-11B-Vision-Instruct`` and two image prompts:
        one implicit prompt and one explicit encoder/decoder prompt.
    """
    engine_args = EngineArgs(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt
            "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
            "multi_modal_data": {
                "image": ImageAsset("stop_sign").pil_image,
            },
        },
        {  # Explicit prompt
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {
                    "image": ImageAsset("stop_sign").pil_image,
                },
            },
            "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
130
131
132


def run_whisper():
    """Build the engine arguments and example prompts for Whisper.

    Returns:
        A ``ModelRequestData`` with the ``EngineArgs`` for
        ``openai/whisper-large-v3-turbo`` and two audio prompts: one implicit
        prompt and one explicit encoder/decoder prompt.
    """
    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    prompts = [
        {  # Test implicit prompt
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        {  # Test explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset("winning_call").audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<|startoftranscript|>",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
163
164
165


# Maps each accepted `--model-type` value to the function that builds the
# engine arguments and prompts for that model.
model_example_map = {
    "donut": run_donut,
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


173
174
def parse_args():
    """Parse the command-line options of this example.

    Returns:
        The namespace produced by the parser, with ``model_type`` (one of the
        keys of ``model_example_map``, default ``"mllama"``) and ``seed``
        (an int, or ``None`` when not given).
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="mllama",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


195
196
197
198
199
def main(args):
    """Run offline generation for the selected model and print the results.

    Args:
        args: Parsed command-line namespace. ``args.model_type`` selects an
            entry of ``model_example_map`` and ``args.seed`` is forwarded to
            ``LLM``.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    # The seed from the command line overrides whatever the builder set.
    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments while generation is running.
    start = time.perf_counter()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    # The interval covers generation and the printing loop above.
    duration = time.perf_counter() - start

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


# Script entry point: parse the CLI options and run the selected example.
if __name__ == "__main__":
    args = parse_args()
    main(args)