audio_language.py 21.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
4
This example shows how to use vLLM for running offline inference
5
with the correct prompt format on audio language models.
6
7
8
9

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
10

11
import os
12
from dataclasses import asdict
13
from typing import Any, NamedTuple
14
15

from huggingface_hub import snapshot_download
16
17
from transformers import AutoTokenizer

18
from vllm import LLM, EngineArgs, SamplingParams
19
from vllm.assets.audio import AudioAsset
20
from vllm.lora.request import LoRARequest
21
from vllm.utils.argparse_utils import FlexibleArgumentParser
22

23
# Sample audio clips used as inputs. `main` slices this list to attach audio to
# each prompt, and `run_voxtral` indexes into it directly.
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
24
25
26
# Question asked of the model, keyed by how many audio items accompany the
# prompt (the value of `--num-audios`).
question_per_audio_count: dict[int, str] = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?",
}
29

30
31
32

class ModelRequestData(NamedTuple):
    """Bundle returned by every ``run_*`` builder describing one example.

    ``main`` sends ``prompt`` when it is non-empty and otherwise falls back to
    ``prompt_token_ids``, so a builder should populate one of the two.
    """

    # Engine configuration used to construct the `LLM`.
    engine_args: EngineArgs
    # Text prompt, including the model-specific audio placeholder tokens.
    prompt: str | None = None
    # Pre-tokenized prompt as a flat list of token ids; used by builders that
    # tokenize the request themselves (see `run_voxtral`).
    prompt_token_ids: list[int] | None = None
    # Pre-built multi-modal inputs. When unset, `main` attaches audio taken
    # from `audio_assets`.
    multi_modal_data: dict[str, Any] | None = None
    # Token ids forwarded to `SamplingParams(stop_token_ids=...)`.
    stop_token_ids: list[int] | None = None
    # LoRA adapters to apply; `main` repeats this list once per prompt.
    lora_requests: list[LoRARequest] | None = None
38
39


40
41
42
43
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

44

45
46
47
# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `nvidia/audio-flamingo-3-hf`.

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of audio items; one `<sound>` token is emitted for
            each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    engine_args = EngineArgs(
        model="nvidia/audio-flamingo-3-hf",
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # AudioFlamingo3 uses the <sound> token for audio; the placeholders come
    # immediately before the question inside the user turn.
    user_turn = "<sound>" * audio_count + question
    chat_lines = [
        "<|im_start|>system",
        "You are a helpful assistant.<|im_end|>",
        "<|im_start|>user",
        f"{user_turn}<|im_end|>",
        "<|im_start|>assistant",
    ]
    # Every line, including the last one, is newline-terminated.
    prompt = "\n".join(chat_lines) + "\n"

    return ModelRequestData(engine_args=engine_args, prompt=prompt)


Ekagra Ranjan's avatar
Ekagra Ranjan committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for the CohereASR checkpoint.

    `question` is not used: the prompt consists solely of fixed control
    tokens. Exactly one audio item per prompt is accepted.

    Args:
        question: Ignored.
        audio_count: Number of audio items; must be 1.

    Returns:
        Request data carrying the engine arguments and the control-token prompt.
    """
    assert audio_count == 1, "CohereASR only support single audio input per prompt"

    # Control tokens, concatenated without separators, that make up the prompt.
    control_tokens = (
        "<|startofcontext|>",
        "<|startoftranscript|>",
        "<|emo:undefined|>",
        "<|en|>",
        "<|en|>",
        "<|pnc|>",
        "<|noitn|>",
        "<|notimestamp|>",
        "<|nodiarize|>",
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            # TODO (ekagra): add HF ckpt after asr release
            model="/host/engines/vllm/audio/2b-release",
            limit_mm_per_prompt={"audio": audio_count},
            trust_remote_code=True,
        ),
        prompt="".join(control_tokens),
    )


96
97
98
99
100
101
102
103
104
105
106
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `nvidia/music-flamingo-2601-hf`.

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of audio items; one `<sound>` token is emitted for
            each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    engine_args = EngineArgs(
        model="nvidia/music-flamingo-2601-hf",
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # The system prompt text is kept verbatim.
    system_prompt = (
        "You are Music Flamingo, a multimodal assistant for language and music. "
        "On each turn you receive an audio clip which contains music and optional "
        "text, you will receive at least one or both; use your world knowledge and "
        "reasoning to help the user with any task. Interpret the entirety of the "
        "content any input music--regardlenss of whether the user calls it audio, "
        "music, or sound."
    )

    # MusicFlamingo prompt placeholders use <sound>; vLLM's MusicFlamingo
    # multimodal processor expands each one into <|sound_bos|> + audio tokens +
    # <|sound_eos|> based on extracted audio feature lengths.
    user_turn = "<sound>" * audio_count + question

    chat_lines = [
        "<|im_start|>system",
        f"{system_prompt}<|im_end|>",
        "<|im_start|>user",
        f"{user_turn}<|im_end|>",
        "<|im_start|>assistant",
    ]
    # Every line, including the last one, is newline-terminated.
    prompt = "\n".join(chat_lines) + "\n"

    return ModelRequestData(engine_args=engine_args, prompt=prompt)


Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `google/gemma-3n-E2B-it`.

    Args:
        question: Text placed after the audio placeholder in the user turn.
        audio_count: Used as the per-prompt audio limit. The prompt always
            contains a single `<audio_soft_token>` regardless of this value.
            NOTE(review): confirm whether multi-audio prompts need one
            placeholder per item.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )
    # The parentheses are required: adjacent string literals are only joined
    # within a single expression. Without them the second literal is a
    # separate no-op statement and the prompt ends right after the question,
    # with no end-of-turn marker and no model turn opener.
    prompt = (
        f"<start_of_turn>user\n<audio_soft_token>{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `zai-org/GLM-ASR-Nano-2512`.

    The prompt is rendered by the model's own tokenizer chat template.

    Args:
        question: Text placed after the audio placeholders in the user message.
        audio_count: Number of audio items; one `<|pad|>` token is emitted for
            each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the rendered prompt.
    """
    model_name = "zai-org/GLM-ASR-Nano-2512"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # GLM-ASR uses the <|pad|> token for audio
    user_content = "<|pad|>" * audio_count + question
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )


181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for FunAudioChat.

    NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of
    writing. Pass a local model path via `--model`; `main` overwrites the
    placeholder model name set here with that value.

    Args:
        question: Text appended after the audio placeholder lines.
        audio_count: Number of audio items; one placeholder line is emitted
            for each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    # One newline-terminated placeholder line per audio item.
    placeholder_lines = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    return ModelRequestData(
        engine_args=EngineArgs(
            model="funaudiochat",
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
            enforce_eager=True,
        ),
        prompt=placeholder_lines + question,
    )


206
207
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `ibm-granite/granite-speech-3.3-8b`.

    NOTE - the settings in this example are somewhat different from what is
    optimal for granite speech, and it is generally recommended to use beam
    search. Check the model README for suggested settings.
    https://huggingface.co/ibm-granite/granite-speech-3.3-8b

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of audio items; one `<|audio|>` token is emitted
            for each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments, the prompt and the speech
        LoRA request.
    """
    model_name = "ibm-granite/granite-speech-3.3-8b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    user_turn = "<|audio|>" * audio_count + question
    prompt = (
        "<|start_of_role|>system<|end_of_role|>"
        "Knowledge Cutoff Date: April 2024.\n"
        "Today's Date: December 19, 2024.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
        "<|end_of_text|>\n"
        "<|start_of_role|>user<|end_of_role|>"
        f"{user_turn}<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )

    # The model has an audio-specific lora directly in its model dir;
    # it should be enabled whenever you pass audio inputs to the model.
    speech_lora = LoRARequest("speech", 1, model_name)

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[speech_lora],
    )


237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
    """Kimi-Audio-7B-Instruct for audio transcription and understanding.

    Args:
        question: Text appended after the audio placeholders. When empty, a
            default transcription instruction is used instead.
        audio_count: Number of audio items; one placeholder token is emitted
            for each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments, the prompt and a stop token.
    """
    engine_args = EngineArgs(
        model="moonshotai/Kimi-Audio-7B-Instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
    placeholders = "<|im_kimia_text_blank|>" * audio_count
    # Fall back to the default transcription prompt for an empty question.
    instruction = question or "Please transcribe the audio"

    return ModelRequestData(
        engine_args=engine_args,
        prompt=placeholders + instruction,
        # Stop at EOS token (151644) to prevent repetition
        stop_token_ids=[151644],
    )


265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# MiDashengLM
def run_midashenglm(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `mispeech/midashenglm-7b`.

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of audio items; one
            `<|audio_bos|><|AUDIO|><|audio_eos|>` placeholder is emitted for
            each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    model_name = "mispeech/midashenglm-7b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Placeholders are concatenated back to back, with no separator.
    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>" * audio_count

    default_system = "You are a helpful language and speech assistant."

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


295
# MiniCPM-O
296
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `openbmb/MiniCPM-o-2_6`.

    The prompt is rendered with a custom chat template passed explicitly to
    the tokenizer rather than the tokenizer's built-in one.

    Args:
        question: Text placed on the line after the audio placeholders.
        audio_count: Number of audio items; one `(<audio>./</audio>)`
            placeholder is emitted for each and the same value is used as the
            per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments, the rendered prompt and
        the stop token ids.
    """
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Resolve the textual stop tokens to their ids.
    stop_token_ids = list(
        map(tokenizer.convert_tokens_to_ids, ("<|im_end|>", "<|endoftext|>"))
    )

    # Chat template, split across lines for readability.
    audio_chat_template = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"  # noqa: E501
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}"
        "{% endif %}"
    )

    user_content = "(<audio>./</audio>)" * audio_count + "\n" + question
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )
325
326


327
# Phi-4-multimodal-instruct
328
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.

    Args:
        question: Text placed after the audio placeholders in the user turn.
        audio_count: Number of audio items; placeholders `<|audio_1|>` ..
            `<|audio_N|>` are emitted and the same value is used as the
            per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments, the prompt and the speech
        LoRA request.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora = LoRARequest("speech", 1, os.path.join(model_path, "speech-lora"))

    # Placeholders are numbered starting from 1.
    numbered_placeholders = "".join(
        f"<|audio_{number}|>" for number in range(1, audio_count + 1)
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_path,
            trust_remote_code=True,
            max_model_len=12800,
            max_num_seqs=2,
            enable_lora=True,
            max_lora_rank=320,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=f"<|user|>{numbered_placeholders}{question}<|end|><|assistant|>",
        lora_requests=[speech_lora],
    )
356
357


358
# Qwen2-Audio
359
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `Qwen/Qwen2-Audio-7B-Instruct`.

    Args:
        question: Text placed after the audio placeholder lines in the user
            turn.
        audio_count: Number of audio items; one labelled placeholder line
            (`Audio 1: ...`, `Audio 2: ...`) is emitted for each and the same
            value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    engine_args = EngineArgs(
        model="Qwen/Qwen2-Audio-7B-Instruct",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Labels are 1-based; each placeholder line is newline-terminated.
    audio_lines = "".join(
        f"Audio {number}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for number in range(1, audio_count + 1)
    )

    chat_lines = [
        "<|im_start|>system",
        "You are a helpful assistant.<|im_end|>",
        "<|im_start|>user",
        f"{audio_lines}{question}<|im_end|>",
        "<|im_start|>assistant",
    ]
    # Every line, including the last one, is newline-terminated.
    prompt = "\n".join(chat_lines) + "\n"

    return ModelRequestData(engine_args=engine_args, prompt=prompt)
387
388


389
390
391
392
393
394
395
396
397
398
399
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `Qwen/Qwen2.5-Omni-7B`.

    Args:
        question: Text placed after the audio placeholder lines in the user
            turn.
        audio_count: Number of audio items; one newline-terminated
            `<|audio_bos|><|AUDIO|><|audio_eos|>` line is emitted for each and
            the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    model_name = "Qwen/Qwen2.5-Omni-7B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    default_system = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


Roger Wang's avatar
Roger Wang committed
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `Qwen/Qwen3-Asr-1.7B`.

    `question` is not used: the user turn contains only audio placeholders.

    Args:
        question: Ignored.
        audio_count: Number of audio items; one newline-terminated
            `<|audio_start|><|audio_pad|><|audio_end|>` line is emitted for
            each and the same value is used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the text prompt.
    """
    placeholder_lines = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
    prompt = (
        "<|im_start|>user\n"
        f"{placeholder_lines}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model="Qwen/Qwen3-Asr-1.7B",
            max_model_len=4096,
            max_num_seqs=5,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )


441
# Ultravox 0.5-1B
442
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `fixie-ai/ultravox-v0_5-llama-3_2-1b`.

    The prompt is rendered by the model's own tokenizer chat template.

    Args:
        question: Text placed after the audio placeholder lines in the user
            message.
        audio_count: Number of audio items; one newline-terminated `<|audio|>`
            line is emitted for each and the same value is used as the
            per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments and the rendered prompt.
    """
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    user_content = "<|audio|>\n" * audio_count + question
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            max_model_len=4096,
            max_num_seqs=5,
            trust_remote_code=True,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )
463
464


465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `mistralai/Voxtral-Mini-3B-2507`.

    Unlike the other builders, this one tokenizes the chat request itself
    with `mistral_common` and returns token ids plus the audio data produced
    by that tokenizer instead of a text prompt.

    Args:
        question: Text chunk placed after the audio chunks in the user message.
        audio_count: Number of clips to load, taken from the start of
            `audio_assets`; also used as the per-prompt audio limit.

    Returns:
        Request data carrying the engine arguments, the prompt token ids and
        the multi-modal audio inputs.
    """
    # Imported lazily so that mistral-common[audio] is only needed for Voxtral.
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.chunk import (
        AudioChunk,
        RawAudio,
        TextChunk,
    )
    from mistral_common.protocol.instruct.messages import (
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    # Load the first `audio_count` clips from their local files.
    clips = [
        Audio.from_file(str(audio_assets[idx].get_local_path()), strict=False)
        for idx in range(audio_count)
    ]

    # The user message holds all audio chunks first, then the question text.
    content = [AudioChunk(input_audio=RawAudio.from_audio(clip)) for clip in clips]
    content.append(TextChunk(text=question))

    request = ChatCompletionRequest(
        messages=[UserMessage(content=content)], model=model_name
    )
    tokenized = tokenizer.encode_chat_completion(request)

    # Pair each tokenizer-returned audio array with its sampling rate.
    audio_inputs = [
        (item.audio_array, item.sampling_rate) for item in tokenized.audios
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=tokenized.tokens,
        multi_modal_data={"audio": audio_inputs},
    )


522
# Whisper
523
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for `openai/whisper-large-v3-turbo`.

    `question` is not used: the prompt is the fixed `<|startoftranscript|>`
    token. Exactly one audio item per prompt is accepted.

    Args:
        question: Ignored.
        audio_count: Number of audio items; must be 1.

    Returns:
        Request data carrying the engine arguments and the fixed prompt.
    """
    assert audio_count == 1, "Whisper only support single audio input per prompt"

    return ModelRequestData(
        engine_args=EngineArgs(
            model="openai/whisper-large-v3-turbo",
            max_model_len=448,
            max_num_seqs=5,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt="<|startoftranscript|>",
    )
540
541
542


# Maps each `--model-type` choice to the builder that produces its request.
# Every builder takes `(question, audio_count)` and returns `ModelRequestData`.
# Insertion order is the order in which the choices are offered on the CLI.
model_example_map = {
    "audioflamingo3": run_audioflamingo3,
    "cohere_asr": run_cohere_asr,
    "funaudiochat": run_funaudiochat,
    "gemma3n": run_gemma3n,
    "glmasr": run_glmasr,
    "granite_speech": run_granite_speech,
    "kimi_audio": run_kimi_audio,
    "midashenglm": run_midashenglm,
    "minicpmo": run_minicpmo,
    "musicflamingo": run_musicflamingo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "qwen3_asr": run_qwen3_asr,
    "ultravox": run_ultravox,
    "voxtral": run_voxtral,
    "whisper": run_whisper,
}
561
562


563
564
def parse_args():
    """Parse the command-line options of this example.

    Returns:
        The namespace produced by the parser, with attributes `model_type`,
        `model`, `num_prompts`, `num_audios`, `seed` and
        `tensor_parallel_size`.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )

    # (flags, keyword options) for each argument, in registration order.
    option_specs = [
        (
            ("--model-type", "-m"),
            {
                "type": str,
                "default": "ultravox",
                "choices": model_example_map.keys(),
                "help": 'Huggingface "model_type".',
            },
        ),
        (
            ("--model",),
            {
                "type": str,
                "default": None,
                "help": "Model ID or local path override. Required for funaudiochat.",
            },
        ),
        (
            ("--num-prompts",),
            {"type": int, "default": 1, "help": "Number of prompts to run."},
        ),
        (
            ("--num-audios",),
            {
                "type": int,
                "default": 1,
                "choices": [0, 1, 2],
                "help": "Number of audio items per prompt.",
            },
        ),
        (
            ("--seed",),
            {
                "type": int,
                "default": 0,
                "help": "Set the seed when initializing `vllm.LLM`.",
            },
        ),
        (
            ("--tensor-parallel-size", "-tp"),
            {
                "type": int,
                "default": None,
                "help": "Tensor parallel size to override the model's default setting. ",
            },
        ),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()


609
610
611
612
613
def main(args):
    """Run batch offline inference for the selected audio language model.

    Builds the model-specific request, constructs the `LLM`, generates
    completions for `args.num_prompts` prompts and prints each generated text.

    Args:
        args: Parsed CLI namespace; reads `model_type`, `model`,
            `num_prompts`, `num_audios`, `seed` and `tensor_parallel_size`.

    Raises:
        ValueError: If the model type is unknown, `--model` is missing for
            funaudiochat, or `tensor_parallel_size` / `num_prompts` is not
            a positive integer.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    if model == "funaudiochat" and not args.model:
        raise ValueError("--model is required when --model-type=funaudiochat")

    if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
        raise ValueError(
            f"tensor_parallel_size must be a positive integer, "
            f"got {args.tensor_parallel_size}"
        )

    # Validate up front, before the engine is constructed, and with an
    # explicit exception: an `assert` is stripped when Python runs with -O.
    if args.num_prompts < 1:
        raise ValueError(
            f"num_prompts must be a positive integer, got {args.num_prompts}"
        )

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )
    if model == "funaudiochat":
        # The builder only sets a placeholder name; use the user-supplied path.
        req_data.engine_args.model = args.model

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    if args.tensor_parallel_size is not None:
        engine_args["tensor_parallel_size"] = args.tensor_parallel_size
    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    def get_input(start, end):
        """Build one engine input using `audio_assets[start:end]` as audio.

        Multi-modal data pre-built by the request builder takes precedence
        over the asset slice.
        """
        mm_data = req_data.multi_modal_data
        if not mm_data:
            mm_data = {}
            if end - start > 0:
                mm_data = {
                    "audio": [
                        asset.audio_and_sample_rate for asset in audio_assets[start:end]
                    ]
                }

        inputs = {"multi_modal_data": mm_data}

        # Prefer the text prompt; fall back to pre-tokenized ids.
        if req_data.prompt:
            inputs["prompt"] = req_data.prompt
        else:
            inputs["prompt_token_ids"] = req_data.prompt_token_ids

        return inputs

    # Batch inference
    if audio_count != 1:
        inputs = get_input(0, audio_count)
        inputs = [inputs] * args.num_prompts
    else:
        # For single audio input, we need to vary the audio input
        # to avoid deduplication in vLLM engine.
        inputs = []
        for i in range(args.num_prompts):
            start = i % len(audio_assets)
            inp = get_input(start, start + 1)
            inputs.append(inp)

    # Add LoRA request if applicable
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the example.
    main(parse_args())