# test_voxtral.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest
import pytest_asyncio
from mistral_common.audio import Audio
9
10
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
11

12
from vllm.tokenizers import MistralTokenizer
13
14
15
16
17
18
19

from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer
from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test

# Model identifier shared by the server fixture, the prompt builder and both
# tests in this module.
MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"

# Command-line flags that select the "mistral" variant for the tokenizer mode,
# the config format and the load format. Kept as a list because it is
# concatenated onto another list of server arguments.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]


@pytest.fixture()
def server(request, audio_assets: AudioTestAssets):
    """Launch a ``RemoteOpenAIServer`` for ``MODEL_NAME`` and yield it.

    The server is started with ``--enforce-eager``, a per-prompt audio limit
    equal to the number of available audio assets, and the mistral format
    flags from ``MISTRAL_FORMAT_ARGS``. The ``VLLM_AUDIO_FETCH_TIMEOUT``
    environment variable is set to ``"30"`` for the server process.

    ``request`` is accepted but not used in this body.

    Yields:
        The object produced by entering the ``RemoteOpenAIServer`` context;
        the context is exited when the fixture is torn down.
    """
    args = [
        "--enforce-eager",
        "--limit-mm-per-prompt",
        # Allow as many audio items per prompt as there are test assets.
        json.dumps({"audio": len(audio_assets)}),
    ] + MISTRAL_FORMAT_ARGS

    with RemoteOpenAIServer(
        MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
    ) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield the async client obtained from the ``server`` fixture.

    The client is used as an async context manager, so it is closed once the
    consuming test has finished.
    """
    async with server.get_async_client() as api_client:
        yield api_client


def _get_prompt(audio_assets, question: str):
    """Build the chat-templated prompt for a question about several audios.

    Args:
        audio_assets: Iterable of audio assets; each must expose
            ``get_local_path()``.
        question: Text placed after the audio chunks in the user message.

    Returns:
        The result of ``MistralTokenizer.apply_chat_template`` for a single
        user message made of one audio chunk per asset followed by the text.
    """
    tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)

    # Iterate the assets directly instead of indexing by position.
    audios = [
        Audio.from_file(str(asset.get_local_path()), strict=False)
        for asset in audio_assets
    ]
    audio_chunks = [
        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
    ]

    text_chunk = TextChunk(text=question)
    # The message is converted to its OpenAI dict form before templating.
    messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]

    return tokenizer.apply_chat_template(messages=messages)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_with_multiple_audios(
    vllm_runner,
    audio_assets: AudioTestAssets,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Run the shared multi-audio test against the Voxtral model.

    A single prompt is built from ``MULTI_AUDIO_PROMPT`` and all audio assets,
    then passed together with each asset's ``audio_and_sample_rate`` to
    ``run_multi_audio_test`` using the mistral tokenizer mode.
    """
    vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
    run_multi_audio_test(
        vllm_runner,
        # One (prompt, audios) pair covering every asset at once.
        [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])],
        MODEL_NAME,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tokenizer_mode="mistral",
    )


@pytest.mark.asyncio
async def test_online_serving(client, audio_assets: AudioTestAssets):
    """Send one chat request carrying every audio asset to the live server.

    The request is capped at 10 tokens and the test checks that exactly one
    choice is returned and that it stopped because of the length limit.

    NOTE(review): the previous docstring said this exercises serving
    "with/without chunked prefill enabled", but the ``server`` fixture in
    this file is not parametrized over chunked prefill -- confirm.
    """

    def asset_to_chunk(asset):
        # Load the asset from disk and serialise it as an OpenAI-style audio
        # chunk; the format is forced to "wav" before conversion.
        audio = Audio.from_file(str(asset.get_local_path()), strict=False)
        audio.format = "wav"
        audio_dict = AudioChunk.from_audio(audio).to_openai()
        return audio_dict

    audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
    text = f"What's happening in these {len(audio_assets)} audio clips?"
    messages = [
        {
            "role": "user",
            "content": [*audio_chunks, {"type": "text", "text": text}],
        }
    ]

    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME, messages=messages, max_tokens=10
    )

    assert len(chat_completion.choices) == 1
    choice = chat_completion.choices[0]
    # With max_tokens=10 the generation is expected to hit the token cap.
    assert choice.finish_reason == "length"