# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
"""
10

11
import os
12
from dataclasses import asdict
13
from typing import Any, NamedTuple
14
15

from huggingface_hub import snapshot_download
16
17
from transformers import AutoTokenizer

18
from vllm import LLM, EngineArgs, SamplingParams
19
from vllm.assets.audio import AudioAsset
20
from vllm.lora.request import LoRARequest
21
from vllm.utils.argparse_utils import FlexibleArgumentParser
22

# Audio clips bundled with vLLM; prompts draw their audio inputs from these.
audio_assets = [AudioAsset(name) for name in ("mary_had_lamb", "winning_call")]
# Question asked for each supported number of audio items per prompt
# (indexed by `--num-audios`).
question_per_audio_count = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?",
}

class ModelRequestData(NamedTuple):
    """Everything one example needs to run a request through `vllm.LLM`.

    `main` sends ``prompt`` when it is set and otherwise falls back to
    ``prompt_token_ids``.
    """

    engine_args: EngineArgs
    # Text prompt with the model's audio placeholders already inserted.
    prompt: str | None = None
    # Pre-tokenized prompt (flat list of token IDs), for examples such as
    # Voxtral that tokenize with their own tokenizer.
    prompt_token_ids: list[int] | None = None
    # Pre-built multimodal inputs; when empty, `main` attaches clips taken
    # from `audio_assets`.
    multi_modal_data: dict[str, Any] | None = None
    stop_token_ids: list[int] | None = None
    # Replicated once per prompt by `main` when set.
    lora_requests: list[LoRARequest] | None = None


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for NVIDIA AudioFlamingo3 (HF-format checkpoint).

    The prompt uses a ChatML-style template (``<|im_start|>``/``<|im_end|>``)
    with one ``<sound>`` placeholder per audio clip placed before the question.
    """
    model_name = "nvidia/audio-flamingo-3-hf"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # AudioFlamingo3 uses <sound> token for audio
    audio_placeholder = "<sound>" * audio_count

    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_placeholder}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for NVIDIA MusicFlamingo.

    Same ChatML-style layout as AudioFlamingo3: a fixed system turn, then a
    user turn holding one ``<sound>`` marker per audio clip followed by the
    question, then an open assistant turn.
    """
    # MusicFlamingo uses <sound> token for audio
    sound_markers = "".join("<sound>" for _ in range(audio_count))

    system_turn = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    user_turn = f"<|im_start|>user\n{sound_markers}{question}<|im_end|>\n"
    chat_prompt = system_turn + user_turn + "<|im_start|>assistant\n"

    return ModelRequestData(
        engine_args=EngineArgs(
            model="nvidia/music-flamingo-2601-hf",
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
            enforce_eager=True,
        ),
        prompt=chat_prompt,
    )


# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Google Gemma 3n (E2B, instruction-tuned).

    The user turn holds one ``<audio_soft_token>`` placeholder per audio clip
    followed by the question, wrapped in Gemma's ``<start_of_turn>`` /
    ``<end_of_turn>`` markers, and the prompt ends with an open model turn.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # One placeholder per audio item, so the number of placeholders equals
    # `audio_count` (previously a single placeholder was always emitted).
    audio_placeholder = "<audio_soft_token>" * audio_count

    # The parentheses matter: without them the second literal is a separate
    # no-op statement and the turn terminator + model-turn header are lost.
    prompt = (
        f"<start_of_turn>user\n{audio_placeholder}{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for GLM-ASR Nano.

    The prompt is rendered with the model's own chat template (remote code
    enabled); each audio clip is represented by one ``<|pad|>`` token placed
    before the question in the single user message.
    """
    model_name = "zai-org/GLM-ASR-Nano-2512"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # GLM-ASR uses <|pad|> token for audio
    user_content = "".join(["<|pad|>"] * audio_count) + question
    rendered = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=rendered,
    )


# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for FunAudioChat.

    No chat template is applied: each audio clip becomes one
    ``<|audio_bos|><|AUDIO|><|audio_eos|>`` line, and the question follows
    immediately after the last one.
    """
    # NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of
    # writing. Pass a local model path via `--model`; `main` substitutes it
    # for the placeholder model name used here.
    audio_lines = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    return ModelRequestData(
        engine_args=EngineArgs(
            model="funaudiochat",
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
            enforce_eager=True,
        ),
        prompt=audio_lines + question,
    )


# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for IBM Granite Speech 3.3 (8B).

    The model's audio LoRA adapter lives in its own model directory and
    should be enabled for audio inputs, so a LoRA request is attached.
    """
    # NOTE - the settings in this example are somewhat different from what is
    # optimal for granite speech, and it is generally recommended to use beam
    # search. Check the model README for suggested settings.
    # https://huggingface.co/ibm-granite/granite-speech-3.3-8b
    model_name = "ibm-granite/granite-speech-3.3-8b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # The model has an audio-specific lora directly in its model dir;
    # it should be enabled whenever you pass audio inputs to the model.
    speech_lora_path = model_name
    audio_placeholder = "<|audio|>" * audio_count
    prompt = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# MiDashengLM
def run_midashenglm(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for MiDashengLM-7B.

    ChatML-style prompt with a speech-assistant system turn; the audio
    placeholders (one ``<|audio_bos|><|AUDIO|><|audio_eos|>`` group per clip,
    no separators) precede the question in the user turn.
    """
    model_name = "mispeech/midashenglm-7b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>" * audio_count

    default_system = "You are a helpful language and speech assistant."

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for MiniCPM-o 2.6.

    A custom chat template is supplied so the generation prompt opens the
    assistant turn with the speaker/TTS marker tokens, and generation stops on
    ``<|im_end|>`` or ``<|endoftext|>``.
    """
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    stop_tokens = ["<|im_end|>", "<|endoftext|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    audio_placeholder = "(<audio>./</audio>)" * audio_count
    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
    messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )


# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.

    Audio clips are referenced as numbered placeholders ``<|audio_1|>``,
    ``<|audio_2|>``, ... and the speech LoRA shipped in the model snapshot
    is attached to every request.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora_path = os.path.join(model_path, "speech-lora")
    placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])

    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen2-Audio-7B-Instruct.

    Each clip is introduced on its own line as
    ``Audio N: <|audio_bos|><|AUDIO|><|audio_eos|>`` (N starting at 1),
    ahead of the question in a ChatML-style user turn.
    """
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "".join(
        [
            f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
            for idx in range(audio_count)
        ]
    )

    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen2.5-Omni-7B.

    Uses the long Qwen system prompt shown below and one
    ``<|audio_bos|><|AUDIO|><|audio_eos|>`` line per clip before the question.
    """
    model_name = "Qwen/Qwen2.5-Omni-7B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    default_system = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Qwen3-ASR
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen3-ASR (1.7B).

    The user turn contains only the audio placeholders. ``question`` is
    accepted so the signature matches the other ``run_*`` helpers, but it is
    not inserted into the prompt.
    """
    model_name = "Qwen/Qwen3-Asr-1.7B"

    audio_in_prompt = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
    prompt = f"<|im_start|>user\n{audio_in_prompt}<|im_end|>\n<|im_start|>assistant\n"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Ultravox v0.5 (Llama-3.2-1B backbone).

    The prompt is rendered with the tokenizer's chat template; each audio
    clip contributes one ``<|audio|>`` line before the question.
    """
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        trust_remote_code=True,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build a pre-tokenized request for Voxtral Mini (3B).

    The chat is encoded with mistral-common's tokenizer instead of a text
    template. The request therefore carries ``prompt_token_ids`` plus the
    ``(audio_array, sampling_rate)`` pairs returned by the tokenizer, rather
    than a text prompt.
    """
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.chunk import (
        AudioChunk,
        RawAudio,
        TextChunk,
    )
    from mistral_common.protocol.instruct.messages import UserMessage
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    # One audio chunk per requested clip (in asset order), then the question.
    content = []
    for clip_idx in range(audio_count):
        clip_path = str(audio_assets[clip_idx].get_local_path())
        clip = Audio.from_file(clip_path, strict=False)
        content.append(AudioChunk(input_audio=RawAudio.from_audio(clip)))
    content.append(TextChunk(text=question))

    request = ChatCompletionRequest(
        messages=[UserMessage(content=content)], model=model_name
    )
    encoded = tokenizer.encode_chat_completion(request)

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=encoded.tokens,
        multi_modal_data={
            "audio": [(enc.audio_array, enc.sampling_rate) for enc in encoded.audios]
        },
    )


# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Whisper large-v3-turbo.

    The prompt is only the ``<|startoftranscript|>`` token; ``question`` is
    not used.

    Raises:
        ValueError: if ``audio_count`` is not exactly 1.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if audio_count != 1:
        raise ValueError("Whisper only supports single audio input per prompt")

    model_name = "openai/whisper-large-v3-turbo"

    prompt = "<|startoftranscript|>"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=448,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Maps the `--model-type` CLI value to the function that builds its request.
model_example_map = {
    "audioflamingo3": run_audioflamingo3,
    "musicflamingo": run_musicflamingo,
    "gemma3n": run_gemma3n,
    "glmasr": run_glmasr,
    "funaudiochat": run_funaudiochat,
    "granite_speech": run_granite_speech,
    "midashenglm": run_midashenglm,
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "qwen3_asr": run_qwen3_asr,
    "ultravox": run_ultravox,
    "voxtral": run_voxtral,
    "whisper": run_whisper,
}


def parse_args():
    """Parse the command-line options for this example."""
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model ID or local path override. Required for funaudiochat.",
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1, help="Number of prompts to run."
    )
    parser.add_argument(
        "--num-audios",
        type=int,
        default=1,
        # Derived from the question table so the CLI can never accept a count
        # that `main` has no question for.
        choices=sorted(question_per_audio_count),
        help="Number of audio items per prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        "-tp",
        type=int,
        default=None,
        help="Tensor parallel size to override the model's default setting. ",
    )

    return parser.parse_args()


def main(args):
    """Run batched offline inference for the selected audio example.

    Validates the CLI arguments, builds the model-specific request, creates a
    `vllm.LLM`, generates ``args.num_prompts`` completions and prints the
    first generated text of each.

    Raises:
        ValueError: for an unknown model type, a missing ``--model`` with
            funaudiochat, or a non-positive tensor-parallel size or prompt
            count.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    if model == "funaudiochat" and not args.model:
        raise ValueError("--model is required when --model-type=funaudiochat")

    if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
        raise ValueError(
            f"tensor_parallel_size must be a positive integer, "
            f"got {args.tensor_parallel_size}"
        )

    # Validate before building the (expensive) LLM; an `assert` would also be
    # stripped under `python -O`.
    if args.num_prompts < 1:
        raise ValueError(
            f"num_prompts must be a positive integer, got {args.num_prompts}"
        )

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )

    # `--model` is documented as a general override; apply it for every model
    # type (it is mandatory for funaudiochat, which has no public default).
    if args.model:
        req_data.engine_args.model = args.model

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    if args.tensor_parallel_size is not None:
        engine_args["tensor_parallel_size"] = args.tensor_parallel_size
    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    def get_input(start, end):
        """Build one `generate()` input using ``audio_assets[start:end]``.

        If the request already carries its own multimodal data (e.g. Voxtral),
        that data is used instead of the bundled assets.
        """
        mm_data = req_data.multi_modal_data
        if not mm_data:
            mm_data = {}
            if end - start > 0:
                mm_data = {
                    "audio": [
                        asset.audio_and_sample_rate for asset in audio_assets[start:end]
                    ]
                }

        inputs = {"multi_modal_data": mm_data}

        if req_data.prompt is not None:
            inputs["prompt"] = req_data.prompt
        else:
            inputs["prompt_token_ids"] = req_data.prompt_token_ids

        return inputs

    # Batch inference
    if audio_count != 1:
        inputs = [get_input(0, audio_count)] * args.num_prompts
    else:
        # For single audio input, we need to vary the audio input
        # to avoid deduplication in vLLM engine.
        inputs = []
        for i in range(args.num_prompts):
            start = i % len(audio_assets)
            inputs.append(get_input(start, start + 1))

    # Add LoRA request if applicable
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    args = parse_args()
    main(args)