test_transcription_validation.py 7.82 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# imports for transcription API validation tests
5
6
7
8
9
10
11
import io
import json

import librosa
import numpy as np
import openai
import pytest
12
import pytest_asyncio
13
14
15
16
import soundfile as sf

from ...utils import RemoteOpenAIServer

# Default speech-recognition model and server flags shared by most tests in
# this module (the module-scoped `server` fixture is started with these).
MODEL_NAME = "openai/whisper-large-v3-turbo"
SERVER_ARGS = ["--enforce-eager"]

# Extra server flags appended for checkpoints whose name starts with
# "mistralai" (e.g. Voxtral): they select the "mistral" handlers for the
# tokenizer, the model config and the weight loading.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]

@pytest.fixture(scope="module")
def server():
    """Start one `MODEL_NAME` server shared by every test in this module.

    The server is launched with `SERVER_ARGS`; its lifetime is bounded by the
    `RemoteOpenAIServer` context manager, which is exited once the last test
    of the module that requested this fixture has finished.
    """
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI-compatible client bound to the shared server.

    The client is entered as an async context manager so that it is closed
    as soon as the requesting test is done with it.
    """
    api_client = server.get_async_client()
    async with api_client as opened_client:
        yield opened_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", [MODEL_NAME, "mistralai/Voxtral-Mini-3B-2507"]
)
async def test_basic_audio(mary_had_lamb, model_name):
    """Transcribe the "Mary had a little lamb" clip with each ASR model.

    Every parametrized model gets its own server; checkpoints from the
    "mistralai" namespace additionally need `MISTRAL_FORMAT_ARGS`. The test
    checks the transcribed text and the `usage["seconds"]` value reported
    by the server.
    """
    # Copy so that `+=` below never mutates the module-level SERVER_ARGS.
    server_args = list(SERVER_ARGS)
    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Close the client explicitly, as the shared `client` fixture does.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0,
            )
        out = json.loads(transcription)
        out_text = out["text"]
        out_usage = out["usage"]
        assert "Mary had a little lamb," in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]


@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
    """Transcribe an Italian sample with Gemma 3n through the transcription API."""
    # Gemma accuracy on some of the audio samples we use is particularly bad,
    # hence we use a different one here. WER is evaluated separately.
    model_name = "google/gemma-3n-E2B-it"
    server_args = list(SERVER_ARGS)

    # max_wait_seconds is raised for this model; presumably it is slow to
    # start up — confirm before lowering it.
    with RemoteOpenAIServer(
        model_name, server_args, max_wait_seconds=480
    ) as remote_server:
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=foscolo,
                language="it",
                response_format="text",
                temperature=0.0,
            )
        out = json.loads(transcription)["text"]
        assert "da cui vergine nacque Venere" in out


@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
    """A text-only model must answer a transcription request with a 400 error."""
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        async with remote_server.get_async_client() as client:
            res = await client.audio.transcriptions.create(
                model=model_name, file=winning_call, language="en", temperature=0.0
            )
        err = res.error
        # Separate asserts so a failure shows which condition was violated.
        assert err["code"] == 400, err
        assert not res.text, res.text
        assert err["message"] == "The model does not support Transcriptions API"

@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb, client):
    """An invalid language code must be rejected with `openai.BadRequestError`."""
    # "hh" is not a valid language code.
    with pytest.raises(openai.BadRequestError):
        await client.audio.transcriptions.create(
            model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0
        )

@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
    """Transcribe a long clip built by repeating the sample several times.

    The clip is long enough to exercise the server's audio splitting; each
    repetition must show up once in the text, and the reported
    `usage["seconds"]` must match the concatenated duration.
    """
    num_repeats = 10

    mary_had_lamb.seek(0)
    audio, sr = librosa.load(mary_had_lamb)
    # Add small silence after each audio for repeatability in the split process
    audio = np.pad(audio, (0, 1600))
    repeated_audio = np.tile(audio, num_repeats)

    # Repeated audio to an in-memory WAV file
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sr, format="WAV")
    buffer.seek(0)

    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=buffer,
        language="en",
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(transcription)
    out_text = out["text"]
    out_usage = out["usage"]
    counts = out_text.count("Mary had a little lamb")
    assert counts == num_repeats, counts
    assert out_usage["seconds"] == 161, out_usage["seconds"]


@pytest.mark.asyncio
async def test_completion_endpoints(client):
    """The speech-to-text server must refuse both text generation endpoints."""
    # MODEL_NAME is an ASR model, so chat and plain completions must fail.
    res = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Chat Completions API"

    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Completions API"


@pytest.mark.asyncio
async def test_streaming_response(winning_call, client):
    """Concatenated streaming deltas must equal the non-streamed transcription."""
    res_no_stream = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        response_format="json",
        language="en",
        temperature=0.0,
    )
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        timeout=30,
    )

    # Reconstruct the full text from the streamed chunks and compare.
    pieces = []
    async for chunk in res:
        pieces.append(chunk.choices[0]["delta"]["content"])
    transcription = "".join(pieces)

    assert transcription == res_no_stream.text


@pytest.mark.asyncio
async def test_stream_options(winning_call, client):
    """Streaming must honour the usage-reporting options in `extra_body`.

    A final chunk without choices (the usage summary) must be sent, and every
    chunk that does carry choices must expose a `usage` attribute.
    """
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True),
        timeout=30,
    )
    final = False
    continuous = True
    async for chunk in res:
        if not chunk.choices:
            # final usage sent
            final = True
        else:
            continuous = continuous and hasattr(chunk, "usage")
    # Separate asserts so a failure shows which option was not honoured.
    assert final, "no final usage chunk was streamed"
    assert continuous, "a streamed chunk was missing `usage`"


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb, client):
    """Extreme sampling parameters must change the output versus greedy decoding.

    Both requests use the same seed, so any difference in the transcriptions
    comes from the sampling parameters rather than from the random state.
    """
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.8,
        extra_body=dict(
            seed=42,
            repetition_penalty=1.9,
            top_k=12,
            top_p=0.4,
            min_p=0.5,
            frequency_penalty=1.8,
            presence_penalty=2.0,
        ),
    )

    greedy_transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.0,
        extra_body=dict(seed=42),
    )

    assert greedy_transcription.text != transcription.text


@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb, client):
    """Supplying a text prompt must not drop the opening words of the audio."""
    prompt = "This is a speech, recorded in a phonograph."
    # The expected opening phrase must appear both without and with a prompt.
    prefix = "The first words I spoke in the original phonograph"

    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(transcription)["text"]
    assert prefix in out

    transcription_wprompt = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        prompt=prompt,
        temperature=0.0,
    )
    out_prompt = json.loads(transcription_wprompt)["text"]
    assert prefix in out_prompt