encoder_decoder_multimodal.py 5.81 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""
This example shows how to use vLLM for offline inference with encoder-decoder
multimodal (audio) models for text generation, using both the explicit
(separate encoder/decoder prompts) and implicit (single prompt) prompt formats.
"""
7

8
import os
9
import time
10
11
from collections.abc import Sequence
from typing import NamedTuple
12

13
from vllm import LLM, EngineArgs, PromptType, SamplingParams
14
from vllm.assets.audio import AudioAsset
15
from vllm.utils.argparse_utils import FlexibleArgumentParser
16
17


18
19
20
21
22
class ModelRequestData(NamedTuple):
    """Bundle of everything one example needs: engine config plus prompts.

    Attributes:
        engine_args: Engine configuration; ``main`` adjusts it and passes it
            to ``LLM.from_engine_args``.
        prompts: Prompts that ``main`` passes to ``LLM.generate``.
    """

    engine_args: EngineArgs
    prompts: Sequence[PromptType]


23
def run_whisper():
    """Build the Whisper (large-v3-turbo) example request.

    Two requests are produced so both prompt formats are exercised:

    * implicit: one prompt dict carrying the decoder text and the audio;
    * explicit: separate ``encoder_prompt`` (empty text + audio) and
      ``decoder_prompt`` entries.

    Side effect: sets ``VLLM_WORKER_MULTIPROC_METHOD=spawn`` in this
    process's environment.

    Returns:
        ModelRequestData with the engine arguments and the two prompts.
    """
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    whisper_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
        dtype="half",
    )

    decoder_start = "<|startoftranscript|>"

    # Implicit form: text and audio live in a single prompt dict.
    implicit_request = {
        "prompt": decoder_start,
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
        },
    }

    # Explicit form: audio is attached to an empty encoder prompt, and the
    # decoder is given the start-of-transcript token.
    explicit_request = {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {
                "audio": AudioAsset("winning_call").audio_and_sample_rate,
            },
        },
        "decoder_prompt": decoder_start,
    }

    return ModelRequestData(
        engine_args=whisper_args,
        prompts=[implicit_request, explicit_request],
    )
56
57


58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def run_fireredasr2():
    """
    FireRedASR2 – Automatic Speech Recognition model.

    Described as a Conformer encoder + Qwen2 LLM decoder used for
    speech-to-text transcription. Audio is supplied through the implicit
    prompt format: each request is a single chat-formatted prompt containing
    the ``<|AUDIO|>`` placeholder token, plus the audio clip.

    Returns:
        ModelRequestData with the engine arguments and one prompt per audio
        asset ("mary_had_lamb", then "winning_call").
    """
    asr_args = EngineArgs(
        model="allendou/FireRedASR2-LLM-vllm",
        max_model_len=448,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
    )

    # Chat-formatted user turn. The Chinese instruction means "please
    # transcribe the audio into text"; <|AUDIO|> is the audio placeholder.
    chat_prompt = (
        "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n"
    )

    # Same instruction for every clip; only the attached audio differs.
    requests = [
        {
            "prompt": chat_prompt,
            "multi_modal_data": {
                "audio": AudioAsset(asset_name).audio_and_sample_rate,
            },
        }
        for asset_name in ("mary_had_lamb", "winning_call")
    ]

    return ModelRequestData(engine_args=asr_args, prompts=requests)


def run_fireredlid():
    """
    FireRedLID – Language Identification model.

    An encoder-decoder model that identifies the spoken language of an audio
    clip (e.g. "en", "zh mandarin"). Its output is described as at most 2
    tokens, presumably why ``max_model_len`` is kept very small here.

    Both requests use the explicit encoder/decoder prompt format: the audio is
    attached to an empty encoder prompt and the decoder starts from ``<sos>``.

    Returns:
        ModelRequestData with the engine arguments and one prompt per audio
        asset ("mary_had_lamb", then "winning_call").
    """
    lid_args = EngineArgs(
        model="PatchyTisa/FireRedLID-vllm",
        max_model_len=8,
        max_num_seqs=16,
        limit_mm_per_prompt={"audio": 1},
    )

    def _explicit_prompt(asset_name):
        # Audio rides on the encoder side; the decoder only gets the start token.
        return {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": AudioAsset(asset_name).audio_and_sample_rate,
                },
            },
            "decoder_prompt": "<sos>",
        }

    lid_prompts = [
        _explicit_prompt(clip) for clip in ("mary_had_lamb", "winning_call")
    ]

    return ModelRequestData(engine_args=lid_args, prompts=lid_prompts)


140
# Maps each ``--model-type`` choice to the function that builds its engine
# arguments and prompts. Iteration order follows insertion order.
model_example_map = dict(
    fireredasr2=run_fireredasr2,
    fireredlid=run_fireredlid,
    whisper=run_whisper,
)


147
148
def parse_args():
    """Parse the command-line options for this example.

    Options:
        --model-type / -m: which entry of ``model_example_map`` to run
            (default ``"whisper"``).
        --seed: seed forwarded to the engine arguments (default ``0``).

    Returns:
        The parsed arguments namespace.
    """
    parser = FlexibleArgumentParser(
        # Every example here is an audio encoder-decoder model, not a vision
        # language model; describe the demo accordingly.
        description="Demo on using vLLM for offline inference with "
        "encoder-decoder multimodal models for text generation"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="whisper",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()


169
170
171
172
173
def main(args):
    """Run offline generation for the selected example and report throughput.

    Args:
        args: Parsed CLI namespace; ``args.model_type`` selects the example
            and ``args.seed`` is copied into the engine arguments.

    Raises:
        ValueError: if ``args.model_type`` is not a key of
            ``model_example_map``.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    req_data = model_example_map[model]()

    # Disable other modalities to save memory: every modality defaults to 0
    # and the example's own limits (right-hand side of `|`) override that.
    engine_args = req_data.engine_args
    default_limits = {"image": 0, "video": 0, "audio": 0}
    limit_mm_per_prompt = default_limits | (engine_args.limit_mm_per_prompt or {})
    engine_args.limit_mm_per_prompt = limit_mm_per_prompt
    engine_args.seed = args.seed
    llm = LLM.from_engine_args(engine_args)

    prompts = req_data.prompts

    # Greedy decoding (temperature=0). Special tokens are kept in the
    # detokenized text so model-specific markers stay visible.
    sampling_params = SamplingParams(
        temperature=0,
        top_p=1.0,
        max_tokens=64,
        skip_special_tokens=False,
    )

    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # so wall-clock adjustments cannot skew the measured duration.
    start = time.perf_counter()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")

    duration = time.perf_counter() - start

    print("Duration:", duration)
    print("RPS:", len(prompts) / duration)


# Script entry point: parse the CLI options, then run the selected example.
if __name__ == "__main__":
    main(parse_args())