test_transcription_validation.py 7.79 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# imports for transcription tests
5
6
7
8
9
10
11
import io
import json

import librosa
import numpy as np
import openai
import pytest
12
import pytest_asyncio
13
14
15
16
import soundfile as sf

from ...utils import RemoteOpenAIServer

17
18
# Model served by the module-scoped `server` fixture and used by the tests
# that rely on the shared `client` fixture.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Command-line flags passed to the server started for MODEL_NAME.
SERVER_ARGS = ["--enforce-eager"]
# Extra command-line flags appended when the model name starts with
# "mistralai" (see test_basic_audio).
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]

28

29
30
31
32
33
34
35
36
37
38
39
40
@pytest.fixture(scope="module")
def server():
    """Yield a RemoteOpenAIServer for MODEL_NAME, shared by the whole module.

    The server object is used as a context manager, so it is exited once the
    last test of the module has finished with it.
    """
    remote = RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS)
    with remote as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the shared `server` fixture.

    The client is entered as an async context manager and exited again when
    the requesting test is torn down.
    """
    client_cm = server.get_async_client()
    async with client_cm as openai_client:
        yield openai_client


41
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]
)
async def test_basic_audio(mary_had_lamb, model_name):
    """Transcribe a short English clip with each parametrized model.

    A dedicated server is started per model. The transcription must contain
    the opening words of the clip and report 16 seconds of audio usage.

    Args:
        mary_had_lamb: audio file fixture defined outside this file.
        model_name: model to serve and to request the transcription from.
    """
    server_args = ["--enforce-eager"]

    # Mistral checkpoints need the mistral tokenizer/config/load formats.
    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Enter the client as a context manager (like the `client` fixture
        # does) so it is exited before the server goes away.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0,
            )

        # With response_format="text" the result is parsed here as a JSON
        # document carrying both the text and the usage information.
        out = json.loads(transcription)
        out_text = out["text"]
        out_usage = out["usage"]
        assert "Mary had a little lamb," in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]
66
67


68
69
70
71
72
73
74
75
76
77
78
79
80
81
@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
    """Transcribe an Italian clip with Gemma 3n and check a known phrase.

    Args:
        foscolo: audio file fixture defined outside this file.
    """
    # Gemma accuracy on some of the audio samples we use is particularly bad,
    # hence we use a different one here. WER is evaluated separately.
    model_name = "google/gemma-3n-E2B-it"
    server_args = ["--enforce-eager"]

    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Enter the client as a context manager (like the `client` fixture
        # does) so it is exited before the server goes away.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=foscolo,
                language="it",
                response_format="text",
                temperature=0.0,
            )

        out = json.loads(transcription)["text"]
        assert "da cui vergine nacque Venere" in out


88
@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
    """Check that a transcription request to a non-ASR model yields a 400 error.

    Args:
        winning_call: audio file fixture defined outside this file.
    """
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # Enter the client as a context manager (like the `client` fixture
        # does) so it is exited before the server goes away.
        async with remote_server.get_async_client() as client:
            res = await client.audio.transcriptions.create(
                model=model_name, file=winning_call, language="en", temperature=0.0
            )

        # The error is read from the response payload rather than raised.
        err = res.error
        assert err["code"] == 400
        assert not res.text
        assert err["message"] == "The model does not support Transcriptions API"
100

101
102

@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb, client):
    """Check that an invalid language code is rejected with BadRequestError.

    Args:
        mary_had_lamb: audio file fixture defined outside this file.
        client: async client bound to the module-scoped server.
    """
    # invalid language
    with pytest.raises(openai.BadRequestError):
        await client.audio.transcriptions.create(
            model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0
        )
109

110

111
112
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
    """Transcribe a clip repeated ten times and check text and usage.

    The clip is decoded, padded with trailing silence, tiled, re-encoded as
    WAV in memory and sent as one request. Every repetition must show up in
    the transcription and the reported usage must be 161 seconds.

    Args:
        mary_had_lamb: audio file fixture defined outside this file.
        client: async client bound to the module-scoped server.
    """
    repetitions = 10

    # Rewind first: the file object may already have been read elsewhere.
    mary_had_lamb.seek(0)
    samples, sample_rate = librosa.load(mary_had_lamb)
    # Add small silence after each audio for repeatability in the split process
    samples = np.pad(samples, (0, 1600))
    repeated_audio = np.tile(samples, repetitions)

    # Repeated audio to buffer
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sample_rate, format="WAV")
    buffer.seek(0)

    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=buffer,
        language="en",
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(transcription)
    out_text = out["text"]
    out_usage = out["usage"]

    counts = out_text.count("Mary had a little lamb")
    assert counts == repetitions, counts
    assert out_usage["seconds"] == 161, out_usage["seconds"]
135
136
137


@pytest.mark.asyncio
async def test_completion_endpoints(client):
    """Check that the text-generation endpoints answer with a 400 error.

    Both the chat completions and the completions endpoints are called with
    MODEL_NAME and are expected to report that the model does not support
    them.

    Args:
        client: async client bound to the module-scoped server.
    """
    # Chat Completions request against the transcription model.
    res = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )
    # The error is read from the response payload rather than raised.
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Chat Completions API"

    # Completions request against the transcription model.
    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Completions API"
152
153
154


@pytest.mark.asyncio
async def test_streaming_response(winning_call, client):
    """Check that streamed chunks concatenate to the non-streamed transcription.

    Args:
        winning_call: audio file fixture defined outside this file.
        client: async client bound to the module-scoped server.
    """
    res_no_stream = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        response_format="json",
        language="en",
        temperature=0.0,
    )
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        timeout=30,
    )

    # Reconstruct from chunks and validate
    pieces = [chunk.choices[0]["delta"]["content"] async for chunk in res]
    transcription = "".join(pieces)

    assert transcription == res_no_stream.text
178
179
180


@pytest.mark.asyncio
async def test_stream_options(winning_call, client):
    """Check usage reporting when streaming with both usage options enabled.

    The stream must contain a chunk without choices (treated as the final
    usage chunk), and every chunk that does carry choices must expose a
    `usage` attribute.

    Args:
        winning_call: audio file fixture defined outside this file.
        client: async client bound to the module-scoped server.
    """
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        extra_body={
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True,
        },
        timeout=30,
    )

    final = False
    continuous = True
    async for chunk in res:
        if not chunk.choices:
            # final usage sent
            final = True
        else:
            continuous = continuous and hasattr(chunk, "usage")

    assert final and continuous
200
201
202


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb, client):
    """
    Compare sampling with params and greedy sampling to assert results
    are different when extreme sampling parameters values are picked.

    Args:
        mary_had_lamb: audio file fixture defined outside this file.
        client: async client bound to the module-scoped server.
    """
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.8,
        extra_body={
            "seed": 42,
            "repetition_penalty": 1.9,
            "top_k": 12,
            "top_p": 0.4,
            "min_p": 0.5,
            "frequency_penalty": 1.8,
            "presence_penalty": 2.0,
        },
    )

    greedy_transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.0,
        extra_body={"seed": 42},
    )

    assert greedy_transcription.text != transcription.text
233
234
235


@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb, client):
    """Check that supplying a prompt does not drop the start of the transcription.

    The same clip is transcribed without and with a prompt; both results must
    contain the expected opening words.

    Args:
        mary_had_lamb: audio file fixture defined outside this file.
        client: async client bound to the module-scoped server.
    """
    prompt = "This is a speech, recorded in a phonograph."
    # Passing a prompt must not cause the beginning of the transcription
    # to be omitted.
    prefix = "The first words I spoke in the original phonograph"

    # Baseline: no prompt.
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(transcription)["text"]
    assert prefix in out

    # Same request with the prompt supplied.
    transcription_wprompt = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        prompt=prompt,
        temperature=0.0,
    )
    out_prompt = json.loads(transcription_wprompt)["text"]
    assert prefix in out_prompt