test_voxtral_realtime.py 6.15 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
5
6
from dataclasses import asdict

import pytest
7
import pytest_asyncio
8
9
10
11
12
13
14
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
    StreamingMode,
    TranscriptionRequest,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
15
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy
16
17
18

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
19
from vllm.engine.arg_utils import AsyncEngineArgs
20
from vllm.v1.engine.async_llm import AsyncLLM
21

22
23
from ....utils import ROCM_ENGINE_KWARGS

24
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"

# Engine settings shared by the synchronous ``LLM`` fixture and the
# ``AsyncLLM`` fixture below.
ENGINE_CONFIG = {
    "model": MODEL_NAME,
    # Sequence-length and concurrency limits.
    "max_model_len": 8192,
    "max_num_seqs": 4,
    # At most one audio clip per prompt.
    "limit_mm_per_prompt": {"audio": 1},
    # Config, weights and tokenizer are all read in the "mistral" format.
    "config_format": "mistral",
    "load_format": "mistral",
    "tokenizer_mode": "mistral",
    "enforce_eager": True,
    "gpu_memory_utilization": 0.9,
    # Unpacked last, so any key it defines overrides the values above.
    **ROCM_ENGINE_KWARGS,
}

# Reference transcriptions, in the same order as the ``audio_assets`` fixture
# ("mary_had_lamb", then "winning_call"). Both tests compare the model's
# output against these verbatim (after ``_normalize``), so they mirror what the
# model emits -- including mishearings such as "quite a slow" -- rather than a
# ground-truth transcript.
EXPECTED_TEXT = [
    (
        " First words I spoke in the original phonograph. "
        "A little piece of practical poetry. Mary had a little lamb,"
        " its fleece was quite a slow, and everywhere that Mary went, "
        "the lamb was sure to go."
    ),
    (
        " And the 0-1 pitch on the way to Edgar Martinez. Swung on"
        " the line. Down the left field line for OBS. Here comes Joy. "
        "Here is Junior to third base. They're going to wave him in. "
        "The throw to the plate will be late. The Mariners are going"
        " to play. For the American League Championship, "
        "I don't believe it. It just continues. My, oh, my."
    ),
]


57
58
59
60
61
62
63
64
def _normalize(texts: list[str]) -> list[str]:
    # The model occasionally transcribes "OBS" as "a base hit" and
    # "oh, my" as "oh my", but both are acoustically valid. Normalise so
    # the assertion is stable across runs and hardware.
    texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
    return texts


@pytest.fixture
def audio_assets() -> list[AudioAsset]:
    """Provide the bundled audio clips, in the same order as ``EXPECTED_TEXT``."""
    clip_names = ("mary_had_lamb", "winning_call")
    return [AudioAsset(clip) for clip in clip_names]


@pytest.fixture
def tokenizer() -> MistralTokenizer:
    """Load the ``MistralTokenizer`` published alongside ``MODEL_NAME``."""
    hub_tokenizer = MistralTokenizer.from_hf_hub(MODEL_NAME)
    return hub_tokenizer


@pytest.fixture
def engine():
    """Yield a synchronous ``LLM`` built from ``ENGINE_CONFIG``.

    Teardown is best-effort: any exception raised while shutting down the
    engine core is suppressed, and the accelerator cache is emptied afterwards.
    """
    engine_args = EngineArgs(**ENGINE_CONFIG)
    llm = LLM(**asdict(engine_args))
    try:
        yield llm
    finally:
        with contextlib.suppress(Exception):
            llm.llm_engine.engine_core.shutdown()
        # NOTE(review): torch is imported lazily here, presumably to keep it
        # out of module import time -- confirm before hoisting.
        import torch

        torch.accelerator.empty_cache()

@pytest_asyncio.fixture
async def async_engine():
    """Yield an ``AsyncLLM`` built from ``ENGINE_CONFIG`` and shut it down after."""
    engine_args = AsyncEngineArgs(**ENGINE_CONFIG)
    llm = AsyncLLM.from_engine_args(engine_args)
    try:
        yield llm
    finally:
        llm.shutdown()

def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
    """Check offline transcription via ``LLM.generate`` against ``EXPECTED_TEXT``.

    Each clip is encoded as an OFFLINE ``TranscriptionRequest``. Its token
    budget is derived from the number of audio tokens the clip maps to, and all
    clips are generated in a single greedy batch.
    """
    audio_config = tokenizer.instruct_tokenizer.tokenizer.audio

    def from_file(file_path: str):
        # Encode one clip and return (prompt token ids, raw audio array).
        audio = Audio.from_file(file_path, strict=False)
        req = TranscriptionRequest(
            audio=RawAudio.from_audio(audio),
            streaming=StreamingMode.OFFLINE,
            language=None,
        )
        tokenized = tokenizer.instruct_tokenizer.encode_transcription(req)

        return (tokenized.tokens, tokenized.audios[0].audio_array)

    tokenized_list = [
        from_file(audio_asset.get_local_path()) for audio_asset in audio_assets
    ]

    inputs = []
    sampling_params = []

    for tokens, audio_array in tokenized_list:
        num_samples = audio_array.shape[0]
        # Budget: the clip's audio-token count minus the prompt length, minus
        # one more token.
        # NOTE(review): the trailing ``- 1`` is presumably reserved for an
        # end-of-stream token -- confirm.
        max_tokens = audio_config.num_audio_tokens(num_samples) - len(tokens) - 1
        sampling_params.append(SamplingParams(temperature=0.0, max_tokens=max_tokens))

        input_dict = {
            "multi_modal_data": {"audio": [(audio_array, None)]},
            "prompt_token_ids": tokens,
        }
        inputs.append(input_dict)

    outputs = engine.generate(
        inputs,
        sampling_params=sampling_params,
    )

    texts = _normalize([out.outputs[0].text for out in outputs])
    # zip() stops at the shorter sequence, so without this check a missing or
    # empty result list would make the comparison loop pass vacuously.
    assert len(texts) == len(EXPECTED_TEXT), (
        f"Expected {len(EXPECTED_TEXT)} outputs, got {len(texts)}"
    )
    for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
        assert got == expected, (
            f"Output mismatch at index {i}:\n"
            f"  got:      {got!r}\n"
            f"  expected: {expected!r}"
        )


@pytest.mark.asyncio
async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine):
    """Check ``AsyncLLM.generate`` fed by a ``VoxtralRealtimeBuffer`` stream.

    For each clip, the encoded audio is appended to a buffer and the buffer's
    input stream is used as the prompt. Each response's last token id is
    recorded and fed back into the buffer. The decoded tokens must match
    ``EXPECTED_TEXT``.
    """
    # Lazy import to avoid CUDA-reinitialization error
    from vllm.model_executor.models.voxtral_realtime import VoxtralRealtimeBuffer

    sampling_params = SamplingParams(temperature=0.0, max_tokens=1)
    audio_config = tokenizer.instruct_tokenizer.audio_encoder.audio_config

    output_tokens_list = []
    for i, audio_asset in enumerate(audio_assets):
        output_tokens = []
        audio = Audio.from_file(audio_asset.get_local_path(), strict=False)

        req = TranscriptionRequest(
            streaming=StreamingMode.OFFLINE,
            audio=RawAudio.from_audio(audio),
            language=None,
        )
        audio_enc = tokenizer.encode_transcription(req)

        buffer = VoxtralRealtimeBuffer(audio_config, audio_enc.tokens)
        await buffer.append_audio(audio_enc.audios[0].audio_array)
        # NOTE(review): appending ``None`` presumably signals end of audio to
        # the buffer -- confirm against VoxtralRealtimeBuffer.
        await buffer.append_audio(None)

        request_id = f"session-{i}"

        async for resp in async_engine.generate(
            prompt=buffer.get_input_stream(),
            sampling_params=sampling_params,
            request_id=request_id,
        ):
            # Keep only the last token id of this response (presumably the
            # newly generated one), record it, and feed it back to the buffer.
            tokens = resp.outputs[0].token_ids[-1:]
            output_tokens.extend(tokens)
            await buffer.append_tokens(tokens)

        output_tokens_list.append(output_tokens)

    texts = _normalize(
        [
            tokenizer.decode(
                output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE
            )
            for output_tokens in output_tokens_list
        ]
    )
    # zip() stops at the shorter sequence, so without this check a missing or
    # empty result list would make the comparison loop pass vacuously.
    assert len(texts) == len(EXPECTED_TEXT), (
        f"Expected {len(EXPECTED_TEXT)} outputs, got {len(texts)}"
    )
    for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
        assert got == expected, (
            f"Output mismatch at index {i}:\n"
            f"  got:      {got!r}\n"
            f"  expected: {expected!r}"
        )