# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
"""

import os
from dataclasses import asdict
from typing import Any, NamedTuple, Optional

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser

# Sample clips shipped with vLLM; a run with N audios uses the first N entries.
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]

# Question asked of the model, keyed by the number of audio clips in the prompt.
question_per_audio_count = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?",
}
29

30
31
32

class ModelRequestData(NamedTuple):
    """Engine configuration plus prompt inputs for one model example.

    Every ``run_*`` builder in this file returns one of these and ``main``
    consumes it. Exactly one of ``prompt`` / ``prompt_token_ids`` is expected
    to be set by the builder.
    """

    # Arguments used to construct the vLLM engine for this model.
    engine_args: EngineArgs
    # Text prompt; `main` prefers this over `prompt_token_ids` when it is set.
    prompt: Optional[str] = None
    # Pre-tokenized prompt (flat list of token ids), used when `prompt` is unset.
    prompt_token_ids: Optional[list[int]] = None
    # Ready-made multimodal inputs; when unset, `main` falls back to the
    # module-level audio assets.
    multi_modal_data: Optional[dict[str, Any]] = None
    # Token ids passed to `SamplingParams` as generation stop tokens.
    stop_token_ids: Optional[list[int]] = None
    # LoRA adapters to attach to the generation request.
    lora_requests: Optional[list[LoRARequest]] = None


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Voxtral (``mistralai/Voxtral-Mini-3B-2507``).

    The chat request is encoded with mistral-common's ``MistralTokenizer``, so
    the result carries prompt token ids and the audio returned by the
    tokenizer rather than a text prompt.

    Args:
        question: Text appended after the audio chunks in the user message.
        audio_count: Number of clips taken from the front of ``audio_assets``.
    """
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.messages import (
        AudioChunk,
        RawAudio,
        TextChunk,
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    mistral_tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    question_chunk = TextChunk(text=question)

    # Load the first `audio_count` clips from disk.
    clips = []
    for idx in range(audio_count):
        clip_path = str(audio_assets[idx].get_local_path())
        clips.append(Audio.from_file(clip_path, strict=False))

    # User message layout: all audio chunks first, then the question text.
    content = [AudioChunk(input_audio=RawAudio.from_audio(clip)) for clip in clips]
    content.append(question_chunk)

    request = ChatCompletionRequest(
        messages=[UserMessage(content=content)],
        model=model_name,
    )
    encoded = mistral_tokenizer.encode_chat_completion(request)

    # Pass vLLM the audio exactly as the tokenizer returned it.
    audio_inputs = [(item.audio_array, item.sampling_rate) for item in encoded.audios]

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=encoded.tokens,
        multi_modal_data={"audio": audio_inputs},
    )


# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Gemma 3n (``google/gemma-3n-E2B-it``).

    Args:
        question: Text placed after the audio placeholder in the user turn.
        audio_count: Per-prompt audio limit passed to the engine.

    Returns:
        Engine arguments plus a text prompt that closes the user turn and
        opens the model turn.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )
    # The parentheses are required: without them the second literal is a
    # separate no-op statement and the prompt ends right after the question,
    # missing "<end_of_turn>" and the model turn.
    # NOTE(review): a single "<audio_soft_token>" is emitted regardless of
    # `audio_count` -- confirm whether 0 or 2 audios need a different prompt.
    prompt = (
        f"<start_of_turn>user\n<audio_soft_token>{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Granite Speech (``ibm-granite/granite-speech-3.3-8b``).

    The settings used here are not the ones recommended for this model; beam
    search is generally advised. See the model README for suggested settings:
    https://huggingface.co/ibm-granite/granite-speech-3.3-8b

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of ``<|audio|>`` placeholders to emit.
    """
    model_name = "ibm-granite/granite-speech-3.3-8b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_tags = "<|audio|>" * audio_count
    system_turn = (
        "<|start_of_role|>system<|end_of_role|>"
        "Knowledge Cutoff Date: April 2024.\n"
        "Today's Date: December 19, 2024.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
        "<|end_of_text|>\n"
    )
    user_turn = (
        f"<|start_of_role|>user<|end_of_role|>{audio_tags}{question}<|end_of_text|>\n"
    )
    assistant_header = "<|start_of_role|>assistant<|end_of_role|>"

    # The audio-specific LoRA ships inside the model directory itself, so the
    # adapter path is the model name; enable it whenever audio is passed in.
    speech_adapter = LoRARequest("speech", 1, model_name)

    return ModelRequestData(
        engine_args=engine_args,
        prompt=system_turn + user_turn + assistant_header,
        lora_requests=[speech_adapter],
    )


# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for MiniCPM-o (``openbmb/MiniCPM-o-2_6``).

    The prompt is rendered with the model's tokenizer, using an explicit
    audio chat template instead of the tokenizer's default one.

    Args:
        question: Text placed on the line after the audio placeholders.
        audio_count: Number of ``(<audio>./</audio>)`` placeholders to emit.
    """
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Generation stops on either of these tokens; vLLM needs their ids.
    stop_tokens = ["<|im_end|>", "<|endoftext|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    audio_placeholder = "(<audio>./</audio>)" * audio_count
    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
    messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )
179
180


# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of numbered ``<|audio_N|>`` placeholders to emit.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora_path = os.path.join(model_path, "speech-lora")
    # Placeholders are 1-based: <|audio_1|>, <|audio_2|>, ...
    placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])

    prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompts,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# Phi-4-multimodal-instruct (Hugging Face revision refs/pr/70)
def run_phi4_multimodal(question: str, audio_count: int) -> ModelRequestData:
    """
    Build the audio request for Phi-4-multimodal-instruct, downloaded at
    revision ``refs/pr/70``. The model accepts image and audio inputs; only
    the audio path is shown here.
    """
    checkpoint_dir = snapshot_download(
        "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
    )

    # The vision and speech LoRA weights sit next to the base model, so the
    # speech adapter has to be addressed by its explicit sub-directory.
    speech_adapter_dir = os.path.join(checkpoint_dir, "speech-lora")

    audio_tags = "<|audio|>" * audio_count
    prompt_text = f"<|user|>{audio_tags}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=checkpoint_dir,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )
    speech_adapter = LoRARequest("speech", 1, speech_adapter_dir)

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt_text,
        lora_requests=[speech_adapter],
    )


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen2-Audio (``Qwen/Qwen2-Audio-7B-Instruct``).

    Args:
        question: Text placed after the audio entries in the user turn.
        audio_count: Number of numbered ``Audio N:`` entries to emit.
    """
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # One "Audio N: ..." line per clip, numbered from 1.
    audio_in_prompt = "".join(
        [
            f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
            for idx in range(audio_count)
        ]
    )

    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen2.5-Omni (``Qwen/Qwen2.5-Omni-7B``).

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of audio placeholders to emit.
    """
    model_name = "Qwen/Qwen2.5-Omni-7B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Unlike Qwen2-Audio, the placeholders are not numbered, so plain string
    # repetition is enough.
    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    default_system = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Ultravox (``fixie-ai/ultravox-v0_5-llama-3_2-1b``).

    The prompt is rendered with the tokenizer's own chat template.

    Args:
        question: Text placed after the audio placeholders in the user message.
        audio_count: Number of ``<|audio|>`` placeholder lines to emit.
    """
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Each placeholder sits on its own line ahead of the question.
    messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        trust_remote_code=True,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Whisper (``openai/whisper-large-v3-turbo``).

    ``question`` is accepted for signature parity with the other builders but
    is not used: the prompt is always ``<|startoftranscript|>``.

    Raises:
        ValueError: If ``audio_count`` is not exactly 1.
    """
    # Raise instead of assert so the check still runs under `python -O`.
    if audio_count != 1:
        raise ValueError("Whisper only support single audio input per prompt")
    model_name = "openai/whisper-large-v3-turbo"

    prompt = "<|startoftranscript|>"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=448,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
349
350
351


# Maps each `--model-type` choice to the builder that produces its request data.
model_example_map = {
    "voxtral": run_voxtral,
    "gemma3n": run_gemma3n,
    "granite_speech": run_granite_speech,
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
    "phi4_multimodal": run_phi4_multimodal,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "ultravox": run_ultravox,
    "whisper": run_whisper,
}
363
364


365
366
def parse_args():
    """Parse the command-line options of this example.

    Returns:
        The parsed namespace with ``model_type``, ``num_prompts``,
        ``num_audios`` and ``seed``.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1, help="Number of prompts to run."
    )
    # The choices mirror the keys of `question_per_audio_count`.
    parser.add_argument(
        "--num-audios",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Number of audio items per prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )

    return parser.parse_args()


398
399
400
401
402
def main(args):
    """Run offline inference for the selected audio language model.

    Builds the model-specific request, constructs the engine, generates
    ``args.num_prompts`` completions and prints each generated text.

    Args:
        args: Parsed options providing ``model_type``, ``num_audios``,
            ``num_prompts`` and ``seed``.

    Raises:
        ValueError: If the model type is unknown or ``num_prompts`` is not
            positive.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    # Reject bad input before the expensive engine construction below.
    if args.num_prompts <= 0:
        raise ValueError("--num-prompts must be a positive integer.")

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    # Use the builder's multimodal data when it supplied any; otherwise fall
    # back to the first `audio_count` module-level audio assets.
    mm_data = req_data.multi_modal_data
    if not mm_data:
        mm_data = {}
        if audio_count > 0:
            mm_data = {
                "audio": [
                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
                ]
            }

    inputs = {"multi_modal_data": mm_data}

    # A text prompt takes precedence; token ids are used only without one.
    if req_data.prompt:
        inputs["prompt"] = req_data.prompt
    else:
        inputs["prompt_token_ids"] = req_data.prompt_token_ids

    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts
    # Add LoRA request if applicable
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


# Script entry point: parse the CLI options and run the example.
if __name__ == "__main__":
    args = parse_args()
    main(args)