"vscode:/vscode.git/clone" did not exist on "db2906108acdc141e8a21e390228c69b1379e3c2"
audio_language.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
4
This example shows how to use vLLM for running offline inference
5
with the correct prompt format on audio language models.
6
7
8
9

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
10

11
import os
12
from typing import Any, NamedTuple
13
14

from huggingface_hub import snapshot_download
15
16
from transformers import AutoTokenizer

17
from vllm import LLM, EngineArgs, SamplingParams
18
from vllm.assets.audio import AudioAsset
19
from vllm.lora.request import LoRARequest
20
from vllm.utils.argparse_utils import FlexibleArgumentParser
21

22
# Sample audio clips used as the example's multimodal inputs.
audio_assets = [
    AudioAsset(asset_name) for asset_name in ("mary_had_lamb", "winning_call")
]

# Question to ask, keyed by how many audio clips accompany the prompt.
question_per_audio_count = {
    0: "What is 1+1?",  # text-only prompt
    1: "What is recited in the audio?",  # single clip
    2: "What sport and what nursery rhyme are referenced?",  # both clips
}
28

29
30
31

class ModelRequestData(NamedTuple):
    """Engine configuration plus request contents for one example model.

    Runners set either ``prompt`` (a text prompt containing the model's
    audio placeholders) or ``prompt_token_ids`` (an already-tokenized prompt).
    """

    # Arguments used to construct the `vllm.LLM` engine.
    engine_args: EngineArgs
    # Text prompt, including model-specific audio placeholder tokens.
    prompt: str | None = None
    # Pre-tokenized prompt: a flat list of token ids (e.g. produced by
    # mistral-common's `encode_chat_completion(...).tokens`).
    prompt_token_ids: list[int] | None = None
    # Multimodal payload to send instead of the default sample clips.
    multi_modal_data: dict[str, Any] | None = None
    # Additional token ids at which generation stops.
    stop_token_ids: list[int] | None = None
    # LoRA adapters to attach to the generated requests.
    lora_requests: list[LoRARequest] | None = None
37
38


39
40
41
42
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

43

44
45
46
# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and ChatML-style prompt for Audio Flamingo 3.

    Every clip is represented by one ``<sound>`` token, and all of them
    precede the question inside the user turn.
    """
    sound_tokens = "".join("<sound>" for _ in range(audio_count))
    user_turn = f"{sound_tokens}{question}"
    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{user_turn}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=EngineArgs(
            model="nvidia/audio-flamingo-3-hf",
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
            enforce_eager=True,
        ),
        prompt=prompt,
    )


Ekagra Ranjan's avatar
Ekagra Ranjan committed
72
73
74
# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build a transcription request for Cohere's ASR model.

    The decoder prompt is a fixed run of control tokens, so ``question`` is
    not used. Exactly one audio clip per prompt is accepted.
    """
    assert audio_count == 1, "CohereASR only support single audio input per prompt"

    # Control tokens are concatenated with no separators between them.
    control_tokens = [
        "<|startofcontext|>",
        "<|startoftranscript|>",
        "<|emo:undefined|>",
        "<|en|>",
        "<|en|>",
        "<|pnc|>",
        "<|noitn|>",
        "<|notimestamp|>",
        "<|nodiarize|>",
    ]
    return ModelRequestData(
        engine_args=EngineArgs(
            model="CohereLabs/cohere-transcribe-03-2026",
            limit_mm_per_prompt={"audio": audio_count},
            trust_remote_code=True,
        ),
        prompt="".join(control_tokens),
    )


94
95
96
97
98
99
100
101
102
103
104
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and ChatML-style prompt for Music Flamingo.

    Each clip gets one ``<sound>`` placeholder; vLLM's MusicFlamingo
    multimodal processor expands each one into ``<|sound_bos|>`` + audio
    tokens + ``<|sound_eos|>`` based on the extracted audio feature lengths.
    """
    engine_args = EngineArgs(
        model="nvidia/music-flamingo-2601-hf",
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # NOTE(review): the system prompt is kept verbatim (including its typos),
    # presumably to match the model's reference prompt - confirm before
    # editing the wording.
    system_prompt = (
        "You are Music Flamingo, a multimodal assistant for language and music. "
        "On each turn you receive an audio clip which contains music and optional "
        "text, you will receive at least one or both; use your world knowledge and "
        "reasoning to help the user with any task. Interpret the entirety of the "
        "content any input music--regardlenss of whether the user calls it audio, "
        "music, or sound."
    )
    turns = [
        ("system", system_prompt),
        ("user", f"{'<sound>' * audio_count}{question}"),
    ]
    rendered_turns = "".join(
        f"<|im_start|>{role}\n{text}<|im_end|>\n" for role, text in turns
    )
    # Leave the assistant turn open so the model generates the reply.
    prompt = rendered_turns + "<|im_start|>assistant\n"

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and prompt for Gemma 3n (E2B, instruction-tuned).

    The prompt uses Gemma's turn markup: a user turn with one
    ``<audio_soft_token>`` per clip followed by the question, then an open
    model turn for generation.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    audio_placeholder = "<audio_soft_token>" * audio_count
    # The parentheses are essential: without them the second literal is a
    # separate, discarded expression statement, and the prompt never closes
    # the user turn or opens the model turn.
    prompt = (
        f"<start_of_turn>user\n{audio_placeholder}{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )


151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and prompt for GLM-ASR Nano.

    The prompt is rendered through the model's own chat template, with one
    ``<|pad|>`` token per audio clip ahead of the question.
    """
    model_name = "zai-org/GLM-ASR-Nano-2512"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    pad_tokens = "<|pad|>" * audio_count
    chat = [{"role": "user", "content": f"{pad_tokens}{question}"}]
    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )


179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and raw-text prompt for FunAudioChat.

    NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of
    writing, so the model name below is only a stand-in; pass a local model
    path via ``--model``. No chat template is applied: each audio segment
    sits on its own line, followed by the question.
    """
    audio_segment = "<|audio_bos|><|AUDIO|><|audio_eos|>\n"
    return ModelRequestData(
        engine_args=EngineArgs(
            model="funaudiochat",
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
            enforce_eager=True,
        ),
        prompt=f"{audio_segment * audio_count}{question}",
    )


204
205
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args, prompt, and speech LoRA for Granite Speech 3.3.

    NOTE: the settings in this example are somewhat different from what is
    optimal for Granite Speech, and beam search is generally recommended.
    See the model README for suggested settings:
    https://huggingface.co/ibm-granite/granite-speech-3.3-8b
    """
    model_name = "ibm-granite/granite-speech-3.3-8b"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_tags = "<|audio|>" * audio_count
    system_text = (
        "Knowledge Cutoff Date: April 2024.\n"
        "Today's Date: December 19, 2024.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
    )
    prompt = (
        f"<|start_of_role|>system<|end_of_role|>{system_text}<|end_of_text|>\n"
        f"<|start_of_role|>user<|end_of_role|>{audio_tags}{question}<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )

    # The audio-specific LoRA lives directly in the model directory and should
    # be enabled whenever audio inputs are passed to the model.
    speech_lora = LoRARequest("speech", 1, model_name)
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[speech_lora],
    )


235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
    """Kimi-Audio-7B-Instruct for audio transcription and understanding.

    Each clip is represented by one ``<|im_kimia_text_blank|>`` placeholder.
    An empty question falls back to a plain transcription request.
    """
    text = question or "Please transcribe the audio"
    blanks = "<|im_kimia_text_blank|>" * audio_count

    return ModelRequestData(
        engine_args=EngineArgs(
            model="moonshotai/Kimi-Audio-7B-Instruct",
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=f"{blanks}{text}",
        # Stop at token 151644 (treated as EOS here) to prevent repetition.
        stop_token_ids=[151644],
    )


263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# MiDashengLM
def run_midashenglm(question: str, audio_count: int):
    """Build the engine args and ChatML-style prompt for MiDashengLM-7B.

    Audio segments are concatenated back to back (no separator) directly
    before the question in the user turn.
    """
    engine_args = EngineArgs(
        model="mispeech/midashenglm-7b",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_segments = "<|audio_bos|><|AUDIO|><|audio_eos|>" * audio_count
    system_text = "You are a helpful language and speech assistant."
    prompt = (
        f"<|im_start|>system\n{system_text}<|im_end|>\n"
        f"<|im_start|>user\n{audio_segments}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)


293
# MiniCPM-O
294
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args, prompt, and stop ids for MiniCPM-o 2.6.

    The prompt is rendered with a custom chat template whose generation
    prompt opens the assistant turn followed by speech/TTS marker tokens.
    """
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Generation stops on either of these special tokens.
    stop_token_ids = list(
        map(tokenizer.convert_tokens_to_ids, ["<|im_end|>", "<|endoftext|>"])
    )

    audio_chat_template = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + "
        "message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}"
        "{% endif %}"
    )
    audio_markers = "(<audio>./</audio>)" * audio_count
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": f"{audio_markers}\n{question}"}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )
323
324


325
# Phi-4-multimodal-instruct
326
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # The vision and speech LoRA adapters ship alongside the base weights, so
    # the speech adapter's path must be specified explicitly.
    speech_lora_path = os.path.join(model_path, "speech-lora")

    # Audio placeholders are numbered from 1: <|audio_1|>, <|audio_2|>, ...
    placeholders = "".join(f"<|audio_{n}|>" for n in range(1, audio_count + 1))
    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )
354
355


356
# Qwen2-Audio
357
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and ChatML-style prompt for Qwen2-Audio-7B-Instruct.

    Each clip is introduced on its own numbered line ("Audio 1: ...").
    """
    engine_args = EngineArgs(
        model="Qwen/Qwen2-Audio-7B-Instruct",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    audio_lines = [
        f"Audio {position}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for position in range(1, audio_count + 1)
    ]
    labelled_audio = "".join(audio_lines)
    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{labelled_audio}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)
385
386


387
388
389
390
391
392
393
394
395
396
397
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int):
    """Build the engine args and ChatML-style prompt for Qwen2.5-Omni-7B.

    Uses the model's default "virtual human" system prompt; each audio
    segment sits on its own line before the question.
    """
    engine_args = EngineArgs(
        model="Qwen/Qwen2.5-Omni-7B",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    segment = "<|audio_bos|><|AUDIO|><|audio_eos|>\n"
    audio_block = segment * audio_count
    system_text = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )
    prompt = (
        f"<|im_start|>system\n{system_text}<|im_end|>\n"
        f"<|im_start|>user\n{audio_block}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)


Roger Wang's avatar
Roger Wang committed
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build a transcription request for Qwen3-ASR 1.7B.

    The user turn contains only the audio segments (one per line), so
    ``question`` is not used.
    """
    engine_args = EngineArgs(
        model="Qwen/Qwen3-Asr-1.7B",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )
    segments = "".join(
        "<|audio_start|><|audio_pad|><|audio_end|>\n" for _ in range(audio_count)
    )
    prompt = (
        "<|im_start|>user\n"
        f"{segments}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)


439
# Ultravox 0.5-1B
440
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine args and prompt for Ultravox v0.5 (Llama 3.2 1B).

    The prompt is rendered through the tokenizer's chat template; the user
    message holds one ``<|audio|>`` line per clip followed by the question.
    """
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    user_content = "".join("<|audio|>\n" for _ in range(audio_count)) + question
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            max_model_len=4096,
            max_num_seqs=5,
            trust_remote_code=True,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )
461
462


463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build a pre-tokenized request for Voxtral Mini 3B.

    The chat is encoded with mistral-common instead of a text prompt. The
    result therefore carries ``prompt_token_ids`` and the tokenizer's audio
    arrays as ``(audio_array, sampling_rate)`` pairs.
    """
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
    from mistral_common.protocol.instruct.messages import UserMessage
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    # A single user turn: all audio chunks first, then the question text.
    content = []
    for index in range(audio_count):
        local_file = str(audio_assets[index].get_local_path())
        clip = Audio.from_file(local_file, strict=False)
        content.append(AudioChunk(input_audio=RawAudio.from_audio(clip)))
    content.append(TextChunk(text=question))

    request = ChatCompletionRequest(
        messages=[UserMessage(content=content)], model=model_name
    )
    encoded = tokenizer.encode_chat_completion(request)

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=encoded.tokens,
        multi_modal_data={
            "audio": [
                (encoded_audio.audio_array, encoded_audio.sampling_rate)
                for encoded_audio in encoded.audios
            ]
        },
    )


520
# Whisper
521
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    """Build a transcription request for Whisper large-v3-turbo.

    The decoder prompt is only the start-of-transcript token, so ``question``
    is not used. Exactly one audio clip per prompt is accepted.
    """
    assert audio_count == 1, "Whisper only support single audio input per prompt"
    return ModelRequestData(
        engine_args=EngineArgs(
            model="openai/whisper-large-v3-turbo",
            max_model_len=448,
            max_num_seqs=5,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt="<|startoftranscript|>",
    )
538
539


540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
# FireRedLID
def run_fireredlid(question: str, audio_count: int) -> ModelRequestData:
    """Build a request for FireRedLID.

    The decoder prompt is only ``<sos>``, so ``question`` is not used.
    Exactly one audio clip per prompt is accepted.
    """
    assert audio_count == 1, "FireRedLID only supports single audio input per prompt"
    return ModelRequestData(
        engine_args=EngineArgs(
            model="PatchyTisa/FireRedLID-vllm",
            max_model_len=8,
            max_num_seqs=5,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt="<sos>",
    )


560
# `--model-type` value -> function that builds that model's ModelRequestData.
# The keys double as the CLI choices, listed in this order.
model_example_map = {
    "audioflamingo3": run_audioflamingo3,
    "cohere_asr": run_cohere_asr,
    "fireredlid": run_fireredlid,
    "funaudiochat": run_funaudiochat,
    "gemma3n": run_gemma3n,
    "glmasr": run_glmasr,
    "granite_speech": run_granite_speech,
    "kimi_audio": run_kimi_audio,
    "midashenglm": run_midashenglm,
    "minicpmo": run_minicpmo,
    "musicflamingo": run_musicflamingo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "qwen3_asr": run_qwen3_asr,
    "ultravox": run_ultravox,
    "voxtral": run_voxtral,
    "whisper": run_whisper,
}
580
581


582
583
def parse_args():
    """Define and parse the command-line interface of this example."""
    parser = FlexibleArgumentParser(
        description=(
            "Demo on using vLLM for offline inference with audio language models"
        )
    )
    add = parser.add_argument

    add(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    add(
        "--model",
        type=str,
        default=None,
        help="Model ID or local path override. Required for funaudiochat.",
    )
    add("--num-prompts", type=int, default=1, help="Number of prompts to run.")
    add(
        "--num-audios",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Number of audio items per prompt.",
    )
    add(
        "--seed",
        type=int,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    add(
        "--tensor-parallel-size",
        "-tp",
        type=int,
        default=None,
        help="Tensor parallel size to override the model's default setting. ",
    )

    return parser.parse_args()


628
629
630
631
632
def main(args):
    """Run batched offline inference for the selected audio model.

    Validates the CLI options and builds the model-specific request through
    ``model_example_map``. Then it constructs a ``vllm.LLM``, generates for
    ``args.num_prompts`` prompts, and prints the first completion of each.

    Raises:
        ValueError: if the model type is unknown, ``--model`` is missing for
            funaudiochat, or ``--tensor-parallel-size`` / ``--num-prompts``
            is not a positive integer.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    if model == "funaudiochat" and not args.model:
        raise ValueError("--model is required when --model-type=funaudiochat")

    if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
        raise ValueError(
            f"tensor_parallel_size must be a positive integer, "
            f"got {args.tensor_parallel_size}"
        )

    # Validate up front (not with `assert`, which -O strips) so a bad value
    # fails before the expensive engine construction below.
    if args.num_prompts < 1:
        raise ValueError(
            f"num_prompts must be a positive integer, got {args.num_prompts}"
        )

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )

    # `--model` is documented as a general model ID / path override, so honor
    # it for every model type (it is mandatory for funaudiochat, checked
    # above). Runners that load a tokenizer or LoRA from their default
    # checkpoint still use that default for those pieces.
    if args.model:
        req_data.engine_args.model = args.model

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    # `|` builds a new dict, so the overrides below leave
    # req_data.engine_args untouched.
    engine_args = vars(req_data.engine_args) | {"seed": args.seed}
    if args.tensor_parallel_size is not None:
        engine_args["tensor_parallel_size"] = args.tensor_parallel_size
    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    def get_input(start, end):
        """Build one generate() input from audio_assets[start:end].

        Multimodal data supplied by the runner takes precedence over the
        sample clips.
        """
        mm_data = req_data.multi_modal_data
        if not mm_data:
            mm_data = {}
            if end - start > 0:
                mm_data = {
                    "audio": [
                        asset.audio_and_sample_rate for asset in audio_assets[start:end]
                    ]
                }

        inputs = {"multi_modal_data": mm_data}

        if req_data.prompt:
            inputs["prompt"] = req_data.prompt
        else:
            inputs["prompt_token_ids"] = req_data.prompt_token_ids

        return inputs

    # Batch inference
    if audio_count != 1:
        inputs = [get_input(0, audio_count)] * args.num_prompts
    else:
        # For single audio input, we need to vary the audio input
        # to avoid deduplication in vLLM engine.
        inputs = []
        for i in range(args.num_prompts):
            start = i % len(audio_assets)
            inputs.append(get_input(start, start + 1))

    # Add LoRA request if applicable
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
717
    args = parse_args()
718
    main(args)