audio_language.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow the corresponding examples
on the model's HuggingFace repository.
"""
10

11
import os
12
from dataclasses import asdict
Patrick von Platen's avatar
Patrick von Platen committed
13
from typing import Any, NamedTuple, Optional
14
15

from huggingface_hub import snapshot_download
16
17
from transformers import AutoTokenizer

18
from vllm import LLM, EngineArgs, SamplingParams
19
from vllm.assets.audio import AudioAsset
20
from vllm.lora.request import LoRARequest
21
22
from vllm.utils import FlexibleArgumentParser

23
# Bundled sample clips; requests use the first `audio_count` of them.
audio_assets = [AudioAsset(name) for name in ("mary_had_lamb", "winning_call")]

# Question to ask, keyed by the number of audio clips in the prompt (0, 1, 2).
question_per_audio_count = dict(
    enumerate(
        [
            "What is 1+1?",
            "What is recited in the audio?",
            "What sport and what nursery rhyme are referenced?",
        ]
    )
)
29

30
31
32

class ModelRequestData(NamedTuple):
    """Everything needed to run one example request for a given model.

    The ``run_*`` helpers set either ``prompt`` (a formatted text prompt) or
    ``prompt_token_ids`` (an already-tokenized prompt); ``main`` uses
    ``prompt`` when it is truthy and otherwise ``prompt_token_ids``.
    """

    # Engine configuration used to construct the `vllm.LLM` instance.
    engine_args: EngineArgs
    # Text prompt already formatted with the model's chat template.
    prompt: Optional[str] = None
    # Pre-tokenized prompt as a flat list of token ids (e.g. the output of
    # the Mistral tokenizer in `run_voxtral`).
    prompt_token_ids: Optional[list[int]] = None
    # Multi-modal inputs; when empty/unset, `main` builds them from
    # `audio_assets`.
    multi_modal_data: Optional[dict[str, Any]] = None
    # Additional token ids that stop generation.
    stop_token_ids: Optional[list[int]] = None
    # LoRA adapters; `main` repeats this list once per prompt.
    lora_requests: Optional[list[LoRARequest]] = None


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build a Voxtral request whose prompt is pre-tokenized by mistral_common.

    The chat request (audio chunks followed by the question) is encoded with
    the Mistral tokenizer. The resulting token ids become the prompt, and the
    audio clips returned by the tokenizer are passed to vLLM as
    ``(audio_array, sampling_rate)`` pairs.
    """
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.messages import (
        AudioChunk,
        RawAudio,
        TextChunk,
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    # Load every requested clip first, then wrap each one in an audio chunk.
    source_clips = [
        Audio.from_file(str(audio_assets[idx].get_local_path()), strict=False)
        for idx in range(audio_count)
    ]
    content = [
        AudioChunk(input_audio=RawAudio.from_audio(clip)) for clip in source_clips
    ]
    content.append(TextChunk(text=question))

    request = ChatCompletionRequest(
        messages=[UserMessage(content=content)], model=model_name
    )
    encoded = tokenizer.encode_chat_completion(request)

    audio_inputs = [
        (processed.audio_array, processed.sampling_rate)
        for processed in encoded.audios
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=encoded.tokens,
        multi_modal_data={"audio": audio_inputs},
    )


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build a Granite Speech request that enables the bundled speech LoRA.

    NOTE: the settings in this example differ somewhat from what is optimal
    for Granite Speech, and beam search is generally recommended. See the
    model README for suggested settings:
    https://huggingface.co/ibm-granite/granite-speech-3.3-8b
    """
    model_name = "ibm-granite/granite-speech-3.3-8b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # One <|audio|> token per clip, placed directly before the question.
    user_turn = "<|audio|>" * audio_count + question
    prompt = (
        "<|start_of_role|>system<|end_of_role|>"
        "Knowledge Cutoff Date: April 2024.\n"
        "Today's Date: December 19, 2024.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
        "<|end_of_text|>\n"
        f"<|start_of_role|>user<|end_of_role|>{user_turn}<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )

    # The audio-specific LoRA lives directly in the model directory, so the
    # model name doubles as the adapter path. It should be enabled whenever
    # audio inputs are passed to the model.
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, model_name)],
    )


130
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build a MiniCPM-o 2.6 request rendered with an audio-specific template."""
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Generation stops at either of these end markers.
    end_markers = ("<|im_end|>", "<|endoftext|>")
    stop_token_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in end_markers]

    # Custom template: the assistant header is followed by the
    # <|spk_bos|><|spk|><|spk_eos|><|tts_bos|> tokens.
    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
    user_content = "(<audio>./</audio>)" * audio_count + "\n" + question

    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )
160
161


162
# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """Build a Phi-4-multimodal-instruct request for audio inputs.

    The model accepts both images and audio; this example only uses audio.
    Clips are referenced by numbered placeholders (``<|audio_1|>``,
    ``<|audio_2|>``, ...), and the checkpoint's ``speech-lora`` adapter is
    attached to the request.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # The vision and speech LoRAs sit alongside the base weights, so the
    # speech adapter path must be given explicitly.
    speech_lora_path = os.path.join(model_path, "speech-lora")

    placeholders = "".join(
        f"<|audio_{position}|>" for position in range(1, audio_count + 1)
    )
    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    speech_lora = LoRARequest("speech", 1, speech_lora_path)
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[speech_lora],
    )
191
192


193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def run_phi4_multimodal(question: str, audio_count: int) -> ModelRequestData:
    """Build a Phi-4-multimodal-instruct audio request (``refs/pr/70`` revision).

    Unlike ``run_phi4mm``, this uses a single repeated ``<|audio|>``
    placeholder per clip and does not enable ``trust_remote_code``. The
    checkpoint's ``speech-lora`` adapter is attached to the request.
    """
    model_path = snapshot_download(
        "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
    )
    # The vision and speech LoRAs sit alongside the base weights, so the
    # speech adapter path must be given explicitly.
    speech_lora_path = os.path.join(model_path, "speech-lora")

    user_turn = "<|audio|>" * audio_count + question
    prompt = f"<|user|>{user_turn}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    speech_lora = LoRARequest("speech", 1, speech_lora_path)
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[speech_lora],
    )


224
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build a Qwen2-Audio-7B-Instruct request with a hand-assembled prompt."""
    engine_args = EngineArgs(
        model="Qwen/Qwen2-Audio-7B-Instruct",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Each clip gets its own "Audio N:" line wrapped in bos/eos markers.
    audio_lines = [
        f"Audio {number}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for number in range(1, audio_count + 1)
    ]
    user_turn = "".join(audio_lines) + question

    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{user_turn}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(engine_args=engine_args, prompt=prompt)
253
254


255
256
257
258
259
260
261
262
263
264
265
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
    """Build a Qwen2.5-Omni-7B audio request.

    The prompt uses the system message in ``default_system`` and one
    ``<|audio_bos|><|AUDIO|><|audio_eos|>`` line per audio clip, followed by
    the question.
    """
    model_name = "Qwen/Qwen2.5-Omni-7B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Every clip uses the same placeholder, so plain repetition suffices.
    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    default_system = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


288
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build an Ultravox v0.5 (Llama-3.2-1B) request via its chat template."""
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # One "<|audio|>" line per clip, then the question.
    audio_tokens = "<|audio|>\n" * audio_count
    conversation = [{"role": "user", "content": audio_tokens + question}]
    prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        trust_remote_code=True,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(engine_args=engine_args, prompt=prompt)
310
311
312


# Whisper
313
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
314
    assert audio_count == 1, "Whisper only support single audio input per prompt"
315
316
317
318
    model_name = "openai/whisper-large-v3-turbo"

    prompt = "<|startoftranscript|>"

319
320
321
322
323
324
325
326
327
328
329
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=448,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
330
331
332


# Maps each `--model-type` choice to the function that builds its request.
model_example_map = dict(
    voxtral=run_voxtral,
    granite_speech=run_granite_speech,
    minicpmo=run_minicpmo,
    phi4_mm=run_phi4mm,
    phi4_multimodal=run_phi4_multimodal,
    qwen2_audio=run_qwen2_audio,
    qwen2_5_omni=run_qwen2_5_omni,
    ultravox=run_ultravox,
    whisper=run_whisper,
)
343
344


345
346
def parse_args():
    """Parse the command-line options for this example.

    Returns:
        The parsed arguments namespace with ``model_type``, ``num_prompts``,
        ``num_audios`` and ``seed`` attributes.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1, help="Number of prompts to run."
    )
    parser.add_argument(
        "--num-audios",
        type=int,
        default=1,
        # Derived from the question table so every accepted value has a
        # matching entry for `main` to look up.
        choices=sorted(question_per_audio_count),
        help="Number of audio items per prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )

    return parser.parse_args()


378
379
380
381
382
def main(args):
    """Run offline audio inference for the model selected on the command line.

    Builds the model-specific request and constructs a `vllm.LLM` with image
    and video inputs disabled. It then generates for ``args.num_prompts``
    copies of the same prompt and prints each generated text.

    Raises:
        ValueError: if ``args.model_type`` is unknown or ``args.num_prompts``
            is not positive.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")
    # Validate before the expensive model load below. Use an exception rather
    # than ``assert``, which is stripped under ``python -O``.
    if args.num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be a positive integer, got {args.num_prompts}."
        )

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    # Some helpers (e.g. `run_voxtral`) prepare their own multi-modal data.
    # Otherwise, feed the first `audio_count` bundled assets.
    mm_data = req_data.multi_modal_data
    if not mm_data:
        mm_data = {}
        if audio_count > 0:
            mm_data = {
                "audio": [
                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
                ]
            }

    inputs = {"multi_modal_data": mm_data}

    if req_data.prompt:
        inputs["prompt"] = req_data.prompt
    else:
        inputs["prompt_token_ids"] = req_data.prompt_token_ids

    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts

    # Add LoRA request if applicable. The helpers in this file supply a
    # one-element list, so repeating it yields one adapter per prompt.
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # Parse the CLI options and run the selected example.
    main(parse_args())