# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow the corresponding examples
in the HuggingFace model repository.
"""

import os
from typing import Any, NamedTuple

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils.argparse_utils import FlexibleArgumentParser

# Sample audio clips (vLLM `AudioAsset`s) used as the audio inputs of the
# examples; `main` slices this list according to `--num-audios`.
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]

# Question asked for each supported number of audio items per prompt.
# The keys match the `--num-audios` choices accepted by `parse_args`.
question_per_audio_count = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?",
}


class ModelRequestData(NamedTuple):
    """Everything one example needs to build an engine and issue a request.

    Each ``run_*`` function returns one of these. ``main`` builds the ``LLM``
    from ``engine_args`` and sends ``prompt`` when it is non-empty, otherwise
    ``prompt_token_ids``.
    """

    # Engine configuration used to construct the `LLM`.
    engine_args: EngineArgs
    # Text prompt containing the model-specific audio placeholder(s).
    prompt: str | None = None
    # Pre-tokenized prompt (a flat list of token ids), used when `prompt` is
    # empty; e.g. Voxtral fills this from `mistral_common` tokenization.
    prompt_token_ids: list[int] | None = None
    # Pre-built multimodal inputs; when unset, `main` loads `audio_assets`.
    multi_modal_data: dict[str, Any] | None = None
    # Extra token ids at which generation should stop.
    stop_token_ids: list[int] | None = None
    # LoRA adapters to apply; `main` repeats them once per prompt.
    lora_requests: list[LoRARequest] | None = None


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless otherwise specified, these settings have been tested to work on a
# single L4.

# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for AudioFlamingo3 (nvidia/audio-flamingo-3-hf).

    The prompt uses ``<|im_start|>``/``<|im_end|>`` chat markup with a generic
    system message and one ``<sound>`` placeholder per audio item placed
    before the question.
    """
    model_name = "nvidia/audio-flamingo-3-hf"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # AudioFlamingo3 uses <sound> token for audio
    audio_placeholder = "<sound>" * audio_count

    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_placeholder}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build a transcription request for CohereASR.

    ``question`` is not used: the prompt is a fixed sequence of control tokens.
    Only a single audio item per prompt is supported.
    """
    assert audio_count == 1, "CohereASR only support single audio input per prompt"
    # TODO (ekagra): add HF ckpt after asr release
    # NOTE: absolute local filesystem path; it must exist on the machine that
    # runs this example.
    model_name = "/host/engines/vllm/audio/2b-release"

    # NOTE(review): the tokens presumably select language (en), punctuation,
    # no ITN, no timestamps and no diarization — confirm against the model
    # documentation before changing them.
    prompt = (
        "<|startofcontext|><|startoftranscript|>"
        "<|emo:undefined|><|en|><|en|><|pnc|><|noitn|>"
        "<|notimestamp|><|nodiarize|>"
    )
    engine_args = EngineArgs(
        model=model_name,
        limit_mm_per_prompt={"audio": audio_count},
        trust_remote_code=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for MusicFlamingo (nvidia/music-flamingo-2601-hf).

    The prompt uses ``<|im_start|>``/``<|im_end|>`` chat markup with a
    music-specific system prompt and one ``<sound>`` placeholder per audio
    item placed before the question.
    """
    model_name = "nvidia/music-flamingo-2601-hf"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # MusicFlamingo prompt placeholders use <sound>; vLLM's MusicFlamingo
    # multimodal processor expands each one into <|sound_bos|> + audio tokens +
    # <|sound_eos|> based on extracted audio feature lengths.
    audio_placeholder = "<sound>" * audio_count

    # NOTE(review): the wording (including "regardlenss") is kept verbatim;
    # presumably it mirrors the model's reference system prompt — confirm
    # before editing, since prompt text can change model outputs.
    system_prompt = (
        "You are Music Flamingo, a multimodal assistant for language and music. "
        "On each turn you receive an audio clip which contains music and optional "
        "text, you will receive at least one or both; use your world knowledge and "
        "reasoning to help the user with any task. Interpret the entirety of the "
        "content any input music--regardlenss of whether the user calls it audio, "
        "music, or sound."
    )

    prompt = (
        "<|im_start|>system\n"
        f"{system_prompt}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_placeholder}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Gemma 3n (google/gemma-3n-E2B-it).

    The prompt is a single Gemma user turn followed by the model-turn header,
    so generation starts with the model's answer.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # One placeholder per audio item so the placeholder count matches the
    # number of audio inputs (including 0 and 2).
    audio_placeholder = "<audio_soft_token>" * audio_count
    # The parentheses make both literals one string. Without them the second
    # literal is a standalone no-op statement and the prompt would be missing
    # "<end_of_turn>" and the model-turn header.
    prompt = (
        f"<start_of_turn>user\n{audio_placeholder}{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for GLM-ASR (zai-org/GLM-ASR-Nano-2512).

    The prompt is rendered with the model tokenizer's chat template; one
    ``<|pad|>`` placeholder per audio item precedes the question.
    """
    model_name = "zai-org/GLM-ASR-Nano-2512"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # GLM-ASR uses <|pad|> token for audio
    audio_placeholder = "<|pad|>" * audio_count

    messages = [{"role": "user", "content": f"{audio_placeholder}{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for FunAudioChat.

    The prompt has no chat markup: one audio placeholder group per item, each
    on its own line, followed by the question.
    """
    # NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of
    # writing. Pass a local model path via `--model`.
    model_name = "funaudiochat"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count
    prompt = f"{audio_in_prompt}{question}"

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Granite Speech 3.3 8B together with its speech LoRA.

    The prompt uses Granite's ``<|start_of_role|>``/``<|end_of_role|>`` markup
    with one ``<|audio|>`` placeholder per audio item before the question.
    """
    # NOTE - the settings in this example are somewhat different from what is
    # optimal for granite speech, and it is generally recommended to use beam
    # search. Check the model README for suggested settings.
    # https://huggingface.co/ibm-granite/granite-speech-3.3-8b
    model_name = "ibm-granite/granite-speech-3.3-8b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # The model has an audio-specific lora directly in its model dir;
    # it should be enabled whenever you pass audio inputs to the model.
    speech_lora_path = model_name
    audio_placeholder = "<|audio|>" * audio_count
    # Same text as a single-line literal, split with implicit concatenation.
    prompt = (
        "<|start_of_role|>system<|end_of_role|>"
        "Knowledge Cutoff Date: April 2024.\n"
        "Today's Date: December 19, 2024.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
        "<|end_of_text|>\n"
        "<|start_of_role|>user<|end_of_role|>"
        f"{audio_placeholder}{question}<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
    """Kimi-Audio-7B-Instruct for audio transcription and understanding.

    The prompt has no chat markup: one ``<|im_kimia_text_blank|>`` placeholder
    per audio item followed by the question. An empty question falls back to
    a transcription instruction.
    """
    model_name = "moonshotai/Kimi-Audio-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
    audio_placeholder = "<|im_kimia_text_blank|>" * audio_count
    # Default prompt for transcription
    if not question:
        question = "Please transcribe the audio"
    prompt = f"{audio_placeholder}{question}"

    # Stop at EOS token (151644) to prevent repetition
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=[151644],
    )


# MiDashengLM
def run_midashenglm(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for MiDashengLM (mispeech/midashenglm-7b).

    The prompt uses ``<|im_start|>``/``<|im_end|>`` chat markup; each audio
    item is an ``<|audio_bos|><|AUDIO|><|audio_eos|>`` group placed, without
    separators, before the question.
    """
    model_name = "mispeech/midashenglm-7b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>" * audio_count

    default_system = "You are a helpful language and speech assistant."

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for MiniCPM-o 2.6 using a custom audio chat template.

    One ``(<audio>./</audio>)`` placeholder per audio item precedes the
    question on its own line. Generation stops at ``<|im_end|>`` or
    ``<|endoftext|>``.
    """
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    stop_tokens = ["<|im_end|>", "<|endoftext|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    audio_placeholder = "(<audio>./</audio>)" * audio_count
    # Renders each message as <|im_start|>role\ncontent<|im_end|>\n; the
    # generation prompt appends the assistant header followed by the
    # <|spk_bos|><|spk|><|spk_eos|><|tts_bos|> tokens.
    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
    messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )


# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.

    Audio items are referenced by numbered placeholders (``<|audio_1|>``,
    ``<|audio_2|>``, ...) and the bundled speech LoRA is applied.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora_path = os.path.join(model_path, "speech-lora")
    placeholders = "".join(f"<|audio_{i + 1}|>" for i in range(audio_count))

    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen2-Audio-7B-Instruct.

    Each audio item gets its own numbered ``Audio N:`` line containing an
    ``<|audio_bos|><|AUDIO|><|audio_eos|>`` placeholder group, followed by the
    question inside ``<|im_start|>``/``<|im_end|>`` chat markup.
    """
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "".join(
        f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for idx in range(audio_count)
    )

    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Qwen2.5-Omni-7B.

    Uses the system prompt below and one
    ``<|audio_bos|><|AUDIO|><|audio_eos|>`` line per audio item before the
    question, inside ``<|im_start|>``/``<|im_end|>`` chat markup.
    """
    model_name = "Qwen/Qwen2.5-Omni-7B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_in_prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>\n" * audio_count

    default_system = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Qwen3-ASR
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build a transcription request for Qwen3-ASR-1.7B.

    ``question`` is not included in the prompt: the user turn contains only
    the ``<|audio_start|><|audio_pad|><|audio_end|>`` group(s), one per item.
    """
    model_name = "Qwen/Qwen3-Asr-1.7B"

    audio_in_prompt = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
    prompt = f"<|im_start|>user\n{audio_in_prompt}<|im_end|>\n<|im_start|>assistant\n"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build the request for Ultravox (fixie-ai/ultravox-v0_5-llama-3_2-1b).

    The prompt is rendered with the tokenizer's chat template, with one
    ``<|audio|>`` line per audio item before the question.
    """
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        trust_remote_code=True,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build a pre-tokenized request for Voxtral Mini 3B.

    Unlike the other examples, the chat request is tokenized here with
    ``mistral_common``. The token ids and the processed audio are returned as
    ``prompt_token_ids`` and ``multi_modal_data``, so ``main`` uses them
    instead of a text prompt and ``audio_assets``.
    """
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.chunk import (
        AudioChunk,
        RawAudio,
        TextChunk,
    )
    from mistral_common.protocol.instruct.messages import (
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    text_chunk = TextChunk(text=question)
    audios = [
        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
        for i in range(audio_count)
    ]
    audio_chunks = [
        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
    ]

    # Audio chunks come first, then the question text.
    messages = [UserMessage(content=[*audio_chunks, text_chunk])]

    req = ChatCompletionRequest(messages=messages, model=model_name)

    tokens = tokenizer.encode_chat_completion(req)
    prompt_ids, audios = tokens.tokens, tokens.audios

    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]

    multi_modal_data = {"audio": audios_and_sr}

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=prompt_ids,
        multi_modal_data=multi_modal_data,
    )


# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    """Build a transcription request for Whisper (large-v3-turbo).

    ``question`` is ignored: the text prompt is only the
    ``<|startoftranscript|>`` token. Only one audio item per prompt is
    supported.
    """
    assert audio_count == 1, "Whisper only support single audio input per prompt"

    model_name = "openai/whisper-large-v3-turbo"

    prompt = "<|startoftranscript|>"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=448,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


# Maps each `--model-type` CLI value to the function that builds its request.
model_example_map = {
    "audioflamingo3": run_audioflamingo3,
    "cohere_asr": run_cohere_asr,
    "funaudiochat": run_funaudiochat,
    "gemma3n": run_gemma3n,
    "glmasr": run_glmasr,
    "granite_speech": run_granite_speech,
    "kimi_audio": run_kimi_audio,
    "midashenglm": run_midashenglm,
    "minicpmo": run_minicpmo,
    "musicflamingo": run_musicflamingo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "qwen3_asr": run_qwen3_asr,
    "ultravox": run_ultravox,
    "voxtral": run_voxtral,
    "whisper": run_whisper,
}


def parse_args():
    """Parse the command-line options for this example."""
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model ID or local path override. Required for funaudiochat.",
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1, help="Number of prompts to run."
    )
    # The choices match the keys of `question_per_audio_count`.
    parser.add_argument(
        "--num-audios",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Number of audio items per prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        "-tp",
        type=int,
        default=None,
        help="Tensor parallel size to override the model's default setting. ",
    )

    return parser.parse_args()


def main(args):
    """Run offline inference for the selected audio model example.

    Builds the model-specific request via ``model_example_map``, constructs an
    ``LLM`` from its engine arguments, generates ``args.num_prompts``
    completions and prints the first candidate of each.

    Raises:
        ValueError: if the model type is unknown, ``--model`` is missing for
            funaudiochat, or ``--tensor-parallel-size`` / ``--num-prompts`` is
            not a positive integer.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    if model == "funaudiochat" and not args.model:
        raise ValueError("--model is required when --model-type=funaudiochat")

    if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
        raise ValueError(
            f"tensor_parallel_size must be a positive integer, "
            f"got {args.tensor_parallel_size}"
        )

    # Validate before the LLM is built so a bad value fails immediately
    # instead of after the model has been loaded.
    if args.num_prompts < 1:
        raise ValueError(
            f"num_prompts must be a positive integer, got {args.num_prompts}"
        )

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )

    # `--model` overrides the example's default checkpoint for any model type
    # (it is mandatory for funaudiochat). Anything the example already derived
    # from its default checkpoint (tokenizer-rendered prompts, LoRA paths) is
    # not recomputed.
    if args.model:
        req_data.engine_args.model = args.model

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = vars(req_data.engine_args) | {"seed": args.seed}
    if args.tensor_parallel_size is not None:
        engine_args["tensor_parallel_size"] = args.tensor_parallel_size
    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    def get_input(start, end):
        """Build one generation input using ``audio_assets[start:end]``.

        Pre-built ``multi_modal_data`` from the example (e.g. Voxtral) takes
        precedence over the asset slice.
        """
        mm_data = req_data.multi_modal_data
        if not mm_data:
            mm_data = {}
            if end - start > 0:
                mm_data = {
                    "audio": [
                        asset.audio_and_sample_rate for asset in audio_assets[start:end]
                    ]
                }

        inputs = {"multi_modal_data": mm_data}

        # Prefer the text prompt; fall back to pre-tokenized ids.
        if req_data.prompt:
            inputs["prompt"] = req_data.prompt
        else:
            inputs["prompt_token_ids"] = req_data.prompt_token_ids

        return inputs

    # Batch inference
    if audio_count != 1:
        inputs = get_input(0, audio_count)
        inputs = [inputs] * args.num_prompts
    else:
        # For single audio input, we need to vary the audio input
        # to avoid deduplication in vLLM engine.
        inputs = []
        for i in range(args.num_prompts):
            start = i % len(audio_assets)
            inp = get_input(start, start + 1)
            inputs.append(inp)

    # Add LoRA request if applicable
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the selected example.
    args = parse_args()
    main(args)