"tests/vscode:/vscode.git/clone" did not exist on "7e8977fcd4e9c3bf6b114c7dc715b28a61b5cdb0"
encoder_decoder_multimodal.py 5.29 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
7

8
import time
9
10
11
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
12

13
from vllm import LLM, EngineArgs, PromptType, SamplingParams
14
15
16
17
18
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.utils import FlexibleArgumentParser


19
20
21
22
23
class ModelRequestData(NamedTuple):
    """Engine settings and prompts produced by one of the ``run_*`` helpers."""

    # Engine configuration; main() expands it into keyword arguments for LLM.
    engine_args: EngineArgs
    # Prompts that main() passes to LLM.generate().
    prompts: Sequence[PromptType]


24
def run_florence2():
    """Assemble the engine settings and demo prompts for Florence-2.

    Two prompts are built: an implicit prompt (task token plus image) and an
    explicit encoder/decoder prompt whose decoder prompt is empty.

    Returns:
        A ``ModelRequestData`` holding the engine arguments and both prompts.
    """
    engine_config = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="Isotr0py/Florence-2-tokenizer",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    # Implicit prompt with task token.
    task_token_prompt = {
        "prompt": "<DETAILED_CAPTION>",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }

    # Explicit encoder/decoder prompt: text and image go to the encoder,
    # the decoder prompt is left empty.
    encoder_input = {
        "prompt": "Describe in detail what is shown in the image.",
        "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
    }
    encoder_decoder_prompt = {
        "encoder_prompt": encoder_input,
        "decoder_prompt": "",
    }

    return ModelRequestData(
        engine_args=engine_config,
        prompts=[task_token_prompt, encoder_decoder_prompt],
    )
52
53
54


def run_mllama():
    """Assemble the engine settings and demo prompts for Llama 3.2 Vision.

    Two prompts are built from the same stop-sign image: an implicit prompt
    and an explicit encoder/decoder prompt.

    Returns:
        A ``ModelRequestData`` holding the engine arguments and both prompts.
    """
    engine_config = EngineArgs(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    # Implicit prompt.
    implicit_prompt = {
        "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }

    # Explicit prompt: the encoder side carries the image placeholder and the
    # image itself, the decoder side carries the text instruction.
    encoder_input = {
        "prompt": "<|image|>",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }
    explicit_prompt = {
        "encoder_prompt": encoder_input,
        "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
    }

    return ModelRequestData(
        engine_args=engine_config,
        prompts=[implicit_prompt, explicit_prompt],
    )
85
86
87


def run_whisper():
    """Assemble the engine settings and demo prompts for Whisper.

    Two prompts are built, each with its own audio asset: an implicit prompt
    and an explicit encoder/decoder prompt with an empty encoder text.

    Returns:
        A ``ModelRequestData`` holding the engine arguments and both prompts.
    """
    engine_config = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    # Test implicit prompt.
    implicit_prompt = {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    }

    # Test explicit encoder/decoder prompt: the audio goes to the encoder with
    # an empty text prompt, the start token goes to the decoder.
    encoder_input = {
        "prompt": "",
        "multi_modal_data": {
            "audio": AudioAsset("winning_call").audio_and_sample_rate,
        },
    }
    explicit_prompt = {
        "encoder_prompt": encoder_input,
        "decoder_prompt": "<|startoftranscript|>",
    }

    return ModelRequestData(
        engine_args=engine_config,
        prompts=[implicit_prompt, explicit_prompt],
    )
118
119
120
121
122
123
124
125
126


# Dispatch table: each key is an accepted ``--model-type`` value and each
# value is the zero-argument function that builds that model's request data.
model_example_map = {
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


127
128
def parse_args():
    """Parse the command-line options for this example.

    Options:
        --model-type / -m: key of ``model_example_map`` selecting the model to
            run (default ``"mllama"``).
        --seed: optional integer seed used when initializing ``vllm.LLM``
            (default ``None``).

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser(
        # The examples cover image and audio inputs, so describe them as
        # encoder-decoder multimodal models rather than vision-only models.
        description="Demo on using vLLM for offline inference with "
        "encoder-decoder multimodal models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="mllama",
        # Pass a concrete list (in dict insertion order) instead of a
        # dict_keys view.
        choices=list(model_example_map),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


149
150
151
152
153
def main(args):
    """Run offline generation for the selected model and print the results.

    Args:
        args: Parsed options. ``args.model_type`` selects an entry of
            ``model_example_map``; ``args.seed`` is added to the engine
            arguments under the ``"seed"`` key.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory: every modality defaults to a
    # limit of 0 and only the limits set by the model's helper override that.
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    # Use a monotonic clock for the elapsed-time measurement; wall-clock
    # time.time() can jump if the system clock is adjusted mid-run.
    start = time.perf_counter()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    # The timed span covers generation and the printing loop above.
    duration = time.perf_counter() - start

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to main().
    main(parse_args())