# encoder_decoder_multimodal.py
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
6

7
import time
8
9
10
from collections.abc import Sequence
from dataclasses import asdict
from typing import NamedTuple
11

12
from vllm import LLM, EngineArgs, PromptType, SamplingParams
13
14
15
16
17
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.utils import FlexibleArgumentParser


18
19
20
21
22
class ModelRequestData(NamedTuple):
    """Bundle of everything one example model needs for a generation run.

    Each ``run_*`` helper in this file returns one of these.
    """

    # Engine configuration used to construct the ``LLM`` instance.
    engine_args: EngineArgs
    # Example prompts passed to ``LLM.generate`` for this model.
    prompts: Sequence[PromptType]


23
def run_florence2():
    """Build the engine configuration and sample prompts for Florence-2.

    Returns:
        A ``ModelRequestData`` holding the ``EngineArgs`` for
        ``microsoft/Florence-2-large`` and two image prompts: one in the
        implicit (single prompt) format and one in the explicit
        encoder/decoder format.
    """
    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
        tokenizer="Isotr0py/Florence-2-tokenizer",
        max_num_seqs=8,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # implicit prompt with task token
            "prompt": "<DETAILED_CAPTION>",
            "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
        },
        {  # explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "Describe in detail what is shown in the image.",
                "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
            },
            "decoder_prompt": "",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
51
52
53


def run_mllama():
    """Build the engine configuration and sample prompts for Llama 3.2 Vision.

    Returns:
        A ``ModelRequestData`` holding the ``EngineArgs`` for
        ``meta-llama/Llama-3.2-11B-Vision-Instruct`` and two image prompts:
        one in the implicit (single prompt) format and one in the explicit
        encoder/decoder format.
    """
    engine_args = EngineArgs(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="half",
    )

    prompts = [
        {  # Implicit prompt
            "prompt": "<|image|><|begin_of_text|>What is the content of this image?",  # noqa: E501
            "multi_modal_data": {
                "image": ImageAsset("stop_sign").pil_image,
            },
        },
        {  # Explicit prompt
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {
                    "image": ImageAsset("stop_sign").pil_image,
                },
            },
            "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
84
85
86


def run_whisper():
    """Build the engine configuration and sample prompts for Whisper.

    Returns:
        A ``ModelRequestData`` holding the ``EngineArgs`` for
        ``openai/whisper-large-v3-turbo`` and two audio prompts: one in the
        implicit (single prompt) format and one in the explicit
        encoder/decoder format.
    """
    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    prompts = [
        {  # Test implicit prompt
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            },
        },
        {  # Test explicit encoder/decoder prompt
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset("winning_call").audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<|startoftranscript|>",
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
117
118
119
120
121
122
123
124
125


# Dispatch table from the ``--model-type`` CLI value to the zero-argument
# builder that returns that model's ``ModelRequestData``.
model_example_map = {
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,
}


126
127
def parse_args():
    """Parse the command-line options of this example.

    Options:
        --model-type / -m: which example to run; restricted to the keys of
            ``model_example_map`` (default ``"mllama"``).
        --seed: optional integer seed (default ``None``).

    Returns:
        The result of ``parser.parse_args()``, exposing ``model_type`` and
        ``seed`` attributes.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="mllama",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


148
149
150
151
152
def main(args):
    """Run offline generation for the selected example model and report timing.

    Args:
        args: Parsed command-line arguments. ``args.model_type`` selects an
            entry of ``model_example_map`` and ``args.seed`` is forwarded to
            the ``LLM`` constructor.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    # The CLI seed overrides whatever seed the engine args carry.
    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    prompts = req_data.prompts

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Stop the clock before printing so that Duration/RPS measure generation
    # only, not console output.
    duration = time.perf_counter() - start

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


# Script entry point: parse the CLI options, then run the selected example.
if __name__ == "__main__":
    args = parse_args()
    main(args)