test_transcription_validation.py 8.97 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# imports for transcription tests
5
6
7
8
9
10
11
import io
import json

import librosa
import numpy as np
import openai
import pytest
12
import pytest_asyncio
13
14
15
16
import soundfile as sf

from ...utils import RemoteOpenAIServer

17
18
# Model served by the module-scoped `server` fixture and used by the tests
# that rely on the shared `client` fixture.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Command-line flags passed to the server for MODEL_NAME.
SERVER_ARGS = ["--enforce-eager"]
# Extra command-line flags appended when the model name starts with
# "mistralai" (see test_basic_audio).
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]

28

29
30
31
32
33
34
35
36
37
38
39
40
@pytest.fixture(scope="module")
def server():
    """Yield one RemoteOpenAIServer for MODEL_NAME, shared by the whole module."""
    launcher = RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS)
    with launcher as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client for the shared `server`, closed after each test."""
    api_client = server.get_async_client()
    async with api_client as async_client:
        yield async_client


41
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]
)
async def test_basic_audio(mary_had_lamb, model_name):
    """Transcribe the `mary_had_lamb` sample with each parametrized model.

    A dedicated server is started per model; the response text and the
    reported usage are both checked.
    """
    server_args = ["--enforce-eager"]

    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Use the client as a context manager (as the `client` fixture does)
        # so it is closed before the server shuts down.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0,
            )
        # The "text" response is parsed as JSON carrying "text" and "usage".
        out = json.loads(transcription)
        out_text = out["text"]
        out_usage = out["usage"]
        assert "Mary had a little lamb," in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]
66
67


68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb):
    """Ensure STT (transcribe) requests can pass LoRA through to generate.

    The server registers a LoRA module named "speech" that points at the
    base model itself, and the request addresses that LoRA name.
    """
    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "1",
    ]

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Use the client as a context manager (as the `client` fixture does)
        # so it is closed before the server shuts down.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=lora_model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0,
            )
    # The "text" response is parsed as JSON carrying "text" and "usage".
    out = json.loads(transcription)
    out_text = out["text"]
    out_usage = out["usage"]
    assert "mary had a little lamb" in out_text
    assert out_usage["seconds"] == 16, out_usage["seconds"]


103
104
105
106
107
108
109
@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
    """Transcribe the Italian `foscolo` sample with Gemma and check a phrase."""
    # Gemma accuracy on some of the audio samples we use is particularly bad,
    # hence we use a different one here. WER is evaluated separately.
    model_name = "google/gemma-3n-E2B-it"
    server_args = ["--enforce-eager"]

    with RemoteOpenAIServer(
        model_name, server_args, max_wait_seconds=480
    ) as remote_server:
        # Use the client as a context manager (as the `client` fixture does)
        # so it is closed before the server shuts down.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=foscolo,
                language="it",
                response_format="text",
                temperature=0.0,
            )
        out = json.loads(transcription)["text"]
        assert "da cui vergine nacque Venere" in out


125
@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
    """Check that a transcription request to a non-ASR model returns a 400 error."""
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # Use the client as a context manager (as the `client` fixture does)
        # so it is closed before the server shuts down.
        async with remote_server.get_async_client() as client:
            res = await client.audio.transcriptions.create(
                model=model_name, file=winning_call, language="en", temperature=0.0
            )
        err = res.error
        # Separate asserts so a failure identifies which condition broke.
        assert err["code"] == 400
        assert not res.text
        assert err["message"] == "The model does not support Transcriptions API"
137

138
139

@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb, client):
    """Check that an invalid language code raises openai.BadRequestError."""
    # invalid language
    with pytest.raises(openai.BadRequestError):
        await client.audio.transcriptions.create(
            model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0
        )
146

147

148
149
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
    """Transcribe the sample repeated ten times and check text and usage."""
    # Rewind the file object before decoding it.
    mary_had_lamb.seek(0)
    audio, sr = librosa.load(mary_had_lamb)
    # Add small silence after each audio for repeatability in the split process
    audio = np.pad(audio, (0, 1600))
    repeated_audio = np.tile(audio, 10)
    # Repeated audio to buffer
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sr, format="WAV")
    # Rewind so the upload starts at the beginning of the WAV data.
    buffer.seek(0)
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=buffer,
        language="en",
        response_format="text",
        temperature=0.0,
    )
    # The "text" response is parsed as JSON carrying "text" and "usage".
    out = json.loads(transcription)
    out_text = out["text"]
    out_usage = out["usage"]
    # One occurrence of the phrase is expected per repetition.
    counts = out_text.count("Mary had a little lamb")
    assert counts == 10, counts
    assert out_usage["seconds"] == 161, out_usage["seconds"]
172
173
174


@pytest.mark.asyncio
async def test_completion_endpoints(client):
    """Check that chat and text completion requests to MODEL_NAME return 400."""
    # MODEL_NAME is the transcription model served by the `server` fixture;
    # both text-generation endpoints are expected to reject it.
    res = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Chat Completions API"

    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Completions API"
189
190
191


@pytest.mark.asyncio
async def test_streaming_response(winning_call, client):
    """Check that streamed chunks concatenate to the non-streamed transcription."""
    res_no_stream = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        response_format="json",
        language="en",
        temperature=0.0,
    )
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        timeout=30,
    )
    # Reconstruct from chunks and validate
    pieces = []
    async for chunk in res:
        pieces.append(chunk.choices[0]["delta"]["content"])
    transcription = "".join(pieces)

    assert transcription == res_no_stream.text
215
216
217


@pytest.mark.asyncio
async def test_stream_options(winning_call, client):
    """Check the stream when both usage-reporting options are enabled.

    Expects one chunk with no choices (the final usage chunk) and a `usage`
    attribute on every chunk that does carry choices.
    """
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True),
        timeout=30,
    )
    final = False
    continuous = True
    async for chunk in res:
        if not chunk.choices:
            # final usage sent
            final = True
        else:
            continuous = continuous and hasattr(chunk, "usage")
    assert final and continuous
237
238
239


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb, client):
    """
    Compare sampling with params and greedy sampling to assert results
    are different when extreme sampling parameters values are picked.
    """
    # Non-default sampling parameters are forwarded through `extra_body`.
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.8,
        extra_body=dict(
            seed=42,
            repetition_penalty=1.9,
            top_k=12,
            top_p=0.4,
            min_p=0.5,
            frequency_penalty=1.8,
            presence_penalty=2.0,
        ),
    )

    # Baseline request: temperature 0.0 with the same seed.
    greedy_transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.0,
        extra_body=dict(seed=42),
    )

    assert greedy_transcription.text != transcription.text
270
271
272


@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb, client):
    """Check that passing a `prompt` does not drop the start of the transcription."""
    prompt = "This is a speech, recorded in a phonograph."
    # The transcription must contain `prefix` both with and without the prompt.
    prefix = "The first words I spoke in the original phonograph"
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(transcription)["text"]
    assert prefix in out
    transcription_wprompt = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        prompt=prompt,
        temperature=0.0,
    )
    out_prompt = json.loads(transcription_wprompt)["text"]
    assert prefix in out_prompt