# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
7

8
import os
9
import time
10
11
12
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
13

14
from vllm import LLM, EngineArgs, PromptType, SamplingParams
15
16
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
汪志鹏's avatar
汪志鹏 committed
17
from vllm.multimodal.utils import fetch_image
18
19
20
from vllm.utils import FlexibleArgumentParser


21
22
23
24
25
class ModelRequestData(NamedTuple):
    """Bundle produced by each ``run_*`` example builder in this file.

    Pairs the engine configuration for one model with the prompts that
    ``main`` sends to that model.
    """

    # Engine configuration; ``main`` expands it into ``LLM(**...)``.
    engine_args: EngineArgs
    # Prompts passed to ``llm.generate``; the builders in this file use
    # both the implicit form and the explicit encoder/decoder form.
    prompts: Sequence[PromptType]


汪志鹏's avatar
汪志鹏 committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def run_donut():
    """Build the engine arguments and the DocVQA prompt for Donut.

    Returns:
        A ``ModelRequestData`` holding one explicit encoder/decoder prompt:
        the encoder side carries the document image plus a placeholder
        text prompt, the decoder side carries the DocVQA question.
    """
    engine_args = EngineArgs(
        model="naver-clova-ix/donut-base-finetuned-docvqa",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="float16",
        hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
    )

    # Why the encoder prompt is exactly 4799 characters long:
    #   * donut-base-finetuned-docvqa takes a 2560 x 1920 input image with a
    #     4 x 4 patch size, i.e. 1920 / 4 = 480 by 2560 / 4 = 640 patches.
    #   * The Swin encoder ("depths": [2, 2, 14, 2]) applies a Patch Merging
    #     step before stages 2, 3 and 4, each halving height and width:
    #     480 x 640 -> 240 x 320 -> 120 x 160 -> 60 x 80.
    #   * vLLM fills the image features through the encoder_prompt, and the
    #     tokenizer adds a `<pad>` token to it, so the text itself must be
    #     60 x 80 - 1 = 4799 tokens long.
    placeholder_text = "$" * 4799

    document_image = fetch_image(
        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"  # noqa: E501
    )

    docvqa_prompt = {
        "encoder_prompt": {
            "prompt": placeholder_text,
            "multi_modal_data": {"image": document_image},
        },
        "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>",  # noqa: E501
    }

    return ModelRequestData(engine_args=engine_args, prompts=[docvqa_prompt])


70
def run_florence2():
    """Build the engine arguments and two image prompts for Florence-2.

    Returns:
        A ``ModelRequestData`` with one implicit prompt (task token only)
        followed by one explicit encoder/decoder prompt.
    """
    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="Isotr0py/Florence-2-tokenizer",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    # Implicit form: a single prompt made of the task token plus the image.
    task_token_prompt = {
        "prompt": "<DETAILED_CAPTION>",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }

    # Explicit form: the encoder input (text + image) and the decoder
    # input (left empty here) are spelled out separately.
    encoder_decoder_prompt = {
        "encoder_prompt": {
            "prompt": "Describe in detail what is shown in the image.",
            "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
        },
        "decoder_prompt": "",
    }

    return ModelRequestData(
        engine_args=engine_args,
        prompts=[task_token_prompt, encoder_decoder_prompt],
    )
98
99
100


def run_mllama():
    """Build the engine arguments and two image prompts for Llama 3.2 Vision.

    Returns:
        A ``ModelRequestData`` with one implicit prompt followed by one
        explicit encoder/decoder prompt, both using the ``stop_sign`` asset.
    """
    engine_args = EngineArgs(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    # Implicit form: one text prompt with the image attached to it.
    implicit_prompt = {
        "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }

    # Explicit form: the image travels with the encoder prompt, the actual
    # instruction is given as the decoder prompt.
    explicit_prompt = {
        "encoder_prompt": {
            "prompt": "<|image|>",
            "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
        },
        "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
    }

    return ModelRequestData(
        engine_args=engine_args,
        prompts=[implicit_prompt, explicit_prompt],
    )
131
132
133


def run_whisper():
    """Build the engine arguments and two audio prompts for Whisper.

    Sets ``VLLM_WORKER_MULTIPROC_METHOD`` to ``"spawn"`` in the process
    environment before the engine arguments are created.

    Returns:
        A ``ModelRequestData`` with one implicit prompt followed by one
        explicit encoder/decoder prompt, each carrying its own audio clip.
    """
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    # Implicit form: the transcript start token with the audio attached.
    implicit_prompt = {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    }

    # Explicit form: the audio goes with an empty encoder prompt and the
    # transcript start token is supplied as the decoder prompt.
    explicit_prompt = {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt": "<|startoftranscript|>",
    }

    return ModelRequestData(
        engine_args=engine_args,
        prompts=[implicit_prompt, explicit_prompt],
    )
166
167
168


# Maps each `--model-type` value to the zero-argument builder that returns
# the engine arguments and prompts for that model. The keys double as the
# accepted CLI choices.
model_example_map = {
    "donut": run_donut,
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


176
177
def parse_args():
    """Parse the command-line options of this example script.

    Options:
        --model-type / -m: which entry of ``model_example_map`` to run
            (defaults to ``"mllama"``).
        --seed: optional integer seed (defaults to ``None``).

    Returns:
        The namespace produced by the parser's ``parse_args()``.
    """
    description = (
        "Demo on using vLLM for offline inference with "
        "vision language models for text generation"
    )
    parser = FlexibleArgumentParser(description=description)

    # Restrict the selectable models to the builders registered above.
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="mllama",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )

    return parser.parse_args()


198
199
200
201
202
def main(args):
    """Run offline generation for the selected example model and report timing.

    Looks up the builder for ``args.model_type``, creates the engine from
    the returned arguments, generates completions for the returned prompts
    and prints each result followed by the elapsed time and requests/second.

    Args:
        args: Parsed command-line namespace; ``model_type`` and ``seed``
            are read from it.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory: every modality starts at 0
    # and only the limits requested by the builder override that.
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    # The CLI seed always overrides whatever the builder left in place.
    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    # Time only the generation call, using a monotonic high-resolution
    # clock so the figure is not skewed by system clock adjustments or by
    # the time spent printing results below.
    start = time.perf_counter()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    duration = time.perf_counter() - start

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


if __name__ == "__main__":
    # Script entry point: parse the CLI flags and run the selected example.
    main(parse_args())