test_voxtral_realtime.py 8.12 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
6
7
8
9
10
11
12
from dataclasses import asdict

import pytest
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
    StreamingMode,
    TranscriptionRequest,
)
13
from mistral_common.tokens.tokenizers.audio import AudioConfig
14
15
16
17
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
18
19
20
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput
21

22
23
24
25
26
27
28
29
30
31
32
33
MODEL_NAME = "mistralai/Voxtral-Mini-3B-Realtime-2602"

# Engine keyword arguments shared by the synchronous (``EngineArgs``/``LLM``)
# and asynchronous (``AsyncEngineArgs``/``AsyncLLM``) fixtures.
ENGINE_CONFIG = {
    "model": MODEL_NAME,
    "max_model_len": 8192,
    "max_num_seqs": 4,
    # at most one audio clip per prompt
    "limit_mm_per_prompt": {"audio": 1},
    "config_format": "mistral",
    "load_format": "mistral",
    "tokenizer_mode": "mistral",
    "enforce_eager": True,
    "gpu_memory_utilization": 0.4,
}
34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

# Reference transcriptions, in the same order as the ``audio_assets`` fixture
# ("mary_had_lamb" first, then "winning_call"). Each string starts with a
# leading space, exactly as the decoded model output does.
EXPECTED_TEXT = [
    (
        " First words I spoke in the original phonograph."
        " A little piece of practical poetry."
        " Mary had a little lamb, it sleeps with quite a snow,"
        " and everywhere that Mary went, the lamb was sure to go."
    ),
    (
        " And the 0-1 pitch on the way to Edgar Martinez."
        " Swung on the line."
        " Down the left field line for OBS."
        " Here comes Joy."
        " Here is Junior to third base."
        " They're going to wave him in."
        " The throw to the plate will be late."
        " The Mariners are going to play."
        " For the American League Championship, I don't believe it."
        " It just continues."
        " My, oh, my."
    ),
]


@pytest.fixture
def audio_assets() -> list[AudioAsset]:
    """Return the audio clips under test, ordered like ``EXPECTED_TEXT``."""
    asset_names = ("mary_had_lamb", "winning_call")
    return [AudioAsset(name) for name in asset_names]


@pytest.fixture
def tokenizer() -> MistralTokenizer:
    """Load the Mistral tokenizer for ``MODEL_NAME`` via ``from_hf_hub``."""
    mistral_tokenizer = MistralTokenizer.from_hf_hub(MODEL_NAME)
    return mistral_tokenizer


@pytest.fixture
def engine() -> LLM:
    """Build a synchronous ``LLM`` engine from ``ENGINE_CONFIG``.

    The shared config is routed through the ``EngineArgs`` dataclass and then
    expanded back into keyword arguments (``asdict``) for the ``LLM``
    constructor.
    """
    engine_args = EngineArgs(**ENGINE_CONFIG)
    return LLM(**asdict(engine_args))


70
71
72
73
@pytest.fixture
def async_engine() -> AsyncLLM:
    """Build an ``AsyncLLM`` engine from the shared ``ENGINE_CONFIG``."""
    return AsyncLLM.from_engine_args(AsyncEngineArgs(**ENGINE_CONFIG))
74

75
76

@pytest.mark.skip(reason="Voxtral streaming is not yet public")
def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
    """Transcribe every audio asset with one offline, batched ``generate`` call.

    Each clip is encoded with ``StreamingMode.OFFLINE`` (the whole audio is
    available up front) and decoded greedily (``temperature=0.0``). The decoded
    texts must equal ``EXPECTED_TEXT`` exactly.
    """
    audio_config = tokenizer.instruct_tokenizer.tokenizer.audio

    def from_file(file_path: str):
        """Encode one audio file into ``(prompt token ids, audio array)``."""
        audio = Audio.from_file(file_path, strict=False)
        req = TranscriptionRequest(
            audio=RawAudio.from_audio(audio),
            streaming=StreamingMode.OFFLINE,
            language=None,
        )
        tokenized = tokenizer.instruct_tokenizer.encode_transcription(req)

        return (tokenized.tokens, tokenized.audios[0].audio_array)

    tokenized_list = [
        from_file(audio_asset.get_local_path()) for audio_asset in audio_assets
    ]

    inputs = []
    sampling_params = []

    for tokens, audio_array in tokenized_list:
        num_samples = audio_array.shape[0]
        # Token budget: number of audio tokens for this clip, minus the prompt
        # length, minus one.
        # NOTE(review): the "- 1" presumably leaves room for a trailing special
        # token; confirm against the model's decoding scheme.
        max_tokens = audio_config.num_audio_tokens(num_samples) - len(tokens) - 1
        sampling_params.append(SamplingParams(temperature=0.0, max_tokens=max_tokens))

        # Same prompt type ``RealTimeAudioInput`` uses for the streaming path.
        inputs.append(
            TokensPrompt(
                prompt_token_ids=tokens,
                multi_modal_data={"audio": [(audio_array, None)]},
            )
        )

    outputs = engine.generate(inputs, sampling_params=sampling_params)

    texts = [out.outputs[0].text for out in outputs]
    assert texts == EXPECTED_TEXT


class RealTimeAudioInput:
    """Replay a complete audio file as if it were arriving in real time.

    The audio is consumed in consecutive windows of ``streaming_size`` samples.
    Every call to :meth:`add_tokens` enqueues one ``StreamingInput`` carrying
    the given prompt tokens plus the current window, widened by ``look_back``
    samples on the left (clamped at 0) and ``look_ahead`` samples on the
    right. Once the window start reaches the end of the audio, a ``None``
    sentinel is enqueued and :meth:`generator` finishes.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        self._tokenizer = tokenizer
        cfg: AudioConfig = tokenizer.instruct_tokenizer.audio_encoder.audio_config
        self._config = cfg

        self._look_ahead_in_ms = cfg.streaming_look_ahead_ms
        self._look_back_in_ms = cfg.streaming_look_back_ms
        self._sampling_rate = cfg.sampling_rate

        # Filled in by ``create`` once the transcription request is encoded.
        self._audio: Audio | None = None

        # Window bounds in samples; advanced by every ``add_tokens`` call.
        self._start = 0
        left_pad = cfg.raw_audio_length_per_tok * cfg.n_left_pad_tokens
        self._end = self.streaming_delay + left_pad + self.streaming_size

        # ``None`` acts as the end-of-stream sentinel for ``generator``.
        self._queue: asyncio.Queue[StreamingInput | None] = asyncio.Queue()

    @classmethod
    async def create(cls, audio: Audio, tokenizer: MistralTokenizer):
        """Encode ``audio`` and enqueue the first streaming input for it."""
        instance = cls(tokenizer)

        # we're doing "OFFLINE" encoding here to right & left pad the audio since
        # we have access to the whole audio
        # if we'd do an actual online realtime streaming application we
        # should instead pass `StreamingMode.ONLINE`
        request = TranscriptionRequest(
            streaming=StreamingMode.OFFLINE,
            audio=RawAudio.from_audio(audio),
            language=None,
        )
        encoded = instance._tokenizer.encode_transcription(request)
        instance._audio = encoded.audios[0]

        # The encoded request's prompt tokens seed the very first chunk.
        await instance.add_tokens(encoded.tokens)

        return instance

    @property
    def look_ahead(self) -> int:
        """Samples added past the right edge of every window."""
        return self._get_len_in_samples(self._look_ahead_in_ms)

    @property
    def look_back(self) -> int:
        """Samples added before the left edge of every window."""
        return self._get_len_in_samples(self._look_back_in_ms)

    @property
    def streaming_delay(self) -> int:
        """``transcription_delay_ms`` expressed in samples."""
        return self._get_len_in_samples(self._config.transcription_delay_ms)

    @property
    def streaming_size(self) -> int:
        """Window step in samples: one frame, i.e. ``1000 / frame_rate`` ms."""
        return self._get_len_in_samples(1000 / self._config.frame_rate)

    def _get_len_in_samples(self, len_in_ms: float) -> int:
        """Convert a duration in milliseconds to a whole number of samples."""
        n_samples = self._sampling_rate * len_in_ms / 1000
        # The duration has to line up exactly with the sample grid.
        assert n_samples.is_integer(), n_samples
        return int(n_samples)

    async def add_tokens(self, tokens: list[int]) -> None:
        """Enqueue ``tokens`` with the next audio window, or stop at the end."""
        assert self._audio is not None
        if self._start >= len(self._audio.audio_array):
            self.stop()
            return

        hi = self._end + self.look_ahead
        lo = max(0, self._start - self.look_back)
        chunk = self._audio.audio_array[lo:hi]

        prompt = TokensPrompt(
            prompt_token_ids=tokens,
            multi_modal_data={"audio": (chunk, None)},
        )
        await self._queue.put(StreamingInput(prompt))

        # Slide the window forward by one step.
        self._start = self._end
        self._end += self.streaming_size

    def stop(self):
        """Enqueue the end-of-stream sentinel so ``generator`` terminates."""
        self._queue.put_nowait(None)

    async def generator(self):
        """Yield queued streaming inputs until the ``None`` sentinel arrives."""
        while True:
            item = await self._queue.get()
            if item is None:
                break
            yield item


@pytest.mark.asyncio
@pytest.mark.skip(reason="Voxtral streaming is not yet public")
async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine):
    """Stream every audio asset chunk by chunk through ``AsyncLLM.generate``.

    ``RealTimeAudioInput`` supplies the engine with one audio window per step.
    Each step generates a single token (``max_tokens=1``), which is fed back
    as the prompt of the next window. The session ends once the audio is
    exhausted and the input generator is closed.
    """
    sampling_params = SamplingParams(temperature=0.0, max_tokens=1)

    output_tokens_list = []
    for i, audio_asset in enumerate(audio_assets):
        output_tokens = []
        audio = Audio.from_file(audio_asset.get_local_path(), strict=False)
        streaming_input = await RealTimeAudioInput.create(
            audio=audio, tokenizer=tokenizer
        )

        request_id = f"session-{i}"

        async for resp in async_engine.generate(
            prompt=streaming_input.generator(),
            sampling_params=sampling_params,
            request_id=request_id,
        ):
            # Take the most recently generated token and feed it back as the
            # prompt of the next streaming chunk.
            tokens = resp.outputs[0].token_ids[-1:]

            output_tokens.extend(tokens)
            await streaming_input.add_tokens(tokens)

        output_tokens_list.append(output_tokens)

    texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]

    # 'true' streaming and 'offline' streaming differ a bit because log-mels are
    # normalized differently, so map the known streaming variants onto the
    # offline reference text before comparing.
    texts[0] = (
        texts[0]
        .replace("He has f", "F")
        .replace("its fleece was quite a slow", "it sleeps with quite a snow")
    )
    texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")

    assert texts == EXPECTED_TEXT