# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# imports for transcription API tests
import io
import json

import librosa
import numpy as np
import openai
import pytest
import pytest_asyncio
import soundfile as sf

from ...utils import RemoteOpenAIServer

# Model served by the module-scoped `server` fixture and targeted by most tests.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Base CLI arguments passed to every RemoteOpenAIServer launched in this module.
SERVER_ARGS = ["--enforce-eager"]
# Extra CLI arguments selecting Mistral's native tokenizer, config and weight
# formats; appended when serving "mistralai/..." checkpoints.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode", "mistral", "--config_format", "mistral",
    "--load_format", "mistral"
]


@pytest.fixture(scope="module")
def server():
    """Provide one RemoteOpenAIServer for MODEL_NAME, shared module-wide.

    The server is entered as a context manager, and that context is exited
    once the module's tests no longer need the fixture.
    """
    remote = RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS)
    with remote as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI-compatible client connected to ``server``.

    The client is used as an async context manager so that it is closed
    when the test using it finishes.
    """
    async with server.get_async_client() as api_client:
        yield api_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
    """Transcribe the ``mary_had_lamb`` clip with each parametrized model.

    Every model gets its own server; Mistral-format checkpoints additionally
    need MISTRAL_FORMAT_ARGS. Checks the transcript text and the
    ``usage["seconds"]`` value of the ``text`` response.
    """
    # Build a fresh list so the shared SERVER_ARGS constant is never mutated.
    server_args = [*SERVER_ARGS]
    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # `async with` closes the client, like the module's `client` fixture.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0)
        out = json.loads(transcription)
        out_text = out['text']
        out_usage = out['usage']
        assert "Mary had a little lamb," in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]


@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
    """Transcribe the Italian ``foscolo`` clip with Gemma 3n."""
    # Gemma accuracy on some of the audio samples we use is particularly bad,
    # hence we use a different one here. WER is evaluated separately.
    model_name = "google/gemma-3n-E2B-it"

    # SERVER_ARGS is passed as-is, as test_non_asr_model also does.
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # `async with` closes the client once the request is done.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=foscolo,
                language="it",
                response_format="text",
                temperature=0.0)
        out = json.loads(transcription)['text']
        assert "da cui vergine nacque Venere" in out


@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
    """A text-only model must answer Transcriptions requests with a 400."""
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # `async with` closes the client once the request is done.
        async with remote_server.get_async_client() as client:
            res = await client.audio.transcriptions.create(model=model_name,
                                                           file=winning_call,
                                                           language="en",
                                                           temperature=0.0)
        err = res.error
        assert err["code"] == 400 and not res.text
        assert err[
            "message"] == "The model does not support Transcriptions API"


@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb, client):
    """An invalid language code must be rejected with ``BadRequestError``."""
    # invalid language
    with pytest.raises(openai.BadRequestError):
        await client.audio.transcriptions.create(model=MODEL_NAME,
                                                 file=mary_had_lamb,
                                                 language="hh",
                                                 temperature=0.0)


@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
    """Transcribe ``mary_had_lamb`` repeated back-to-back several times.

    Every repetition must appear in the transcript, and the response's
    ``usage["seconds"]`` must match the length of the repeated clip.
    """
    # Single source of truth for the tiling and the expected phrase count.
    n_repeats = 10

    mary_had_lamb.seek(0)
    audio, sr = librosa.load(mary_had_lamb)
    # Add small silence after each audio for repeatability in the split process
    audio = np.pad(audio, (0, 1600))
    repeated_audio = np.tile(audio, n_repeats)
    # Repeated audio to buffer
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sr, format='WAV')
    buffer.seek(0)

    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=buffer,
        language="en",
        response_format="text",
        temperature=0.0)
    out = json.loads(transcription)
    out_text = out['text']
    out_usage = out['usage']
    counts = out_text.count("Mary had a little lamb")
    assert counts == n_repeats, counts
    assert out_usage["seconds"] == 161, out_usage["seconds"]


@pytest.mark.asyncio
async def test_completion_endpoints(client):
    """The ASR model (MODEL_NAME) must reject text-generation endpoints.

    Both the Chat Completions and the Completions request are expected to
    return a 400 error naming the unsupported API.
    """
    res = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "system",
            "content": "You are a helpful assistant."
        }])
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Chat Completions API"

    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Completions API"


@pytest.mark.asyncio
async def test_streaming_response(winning_call, client):
    """Streamed deltas must concatenate to the non-streamed transcript."""
    res_no_stream = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        response_format="json",
        language="en",
        temperature=0.0)
    res = await client.audio.transcriptions.create(model=MODEL_NAME,
                                                   file=winning_call,
                                                   language="en",
                                                   temperature=0.0,
                                                   stream=True,
                                                   timeout=30)
    # Reconstruct from chunks (joined once, not via repeated `+=`) and validate
    pieces = [chunk.choices[0]['delta']['content'] async for chunk in res]
    transcription = "".join(pieces)

    assert transcription == res_no_stream.text


@pytest.mark.asyncio
async def test_stream_options(winning_call, client):
    """Streaming with usage options must report usage on every chunk.

    ``stream_continuous_usage_stats`` is requested, so every chunk that has
    choices must carry ``usage``. ``stream_include_usage`` is requested, so a
    choice-less chunk must also appear.
    """
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        extra_body=dict(stream_include_usage=True,
                        stream_continuous_usage_stats=True),
        timeout=30)
    final = False
    continuous = True
    async for chunk in res:
        if not chunk.choices:
            # final usage sent
            final = True
        else:
            continuous = continuous and hasattr(chunk, 'usage')
    assert final and continuous, (final, continuous)


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb, client):
    """
    Compare sampling with params and greedy sampling to assert results
    are different when extreme sampling parameters values are picked.
    """
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.8,
        extra_body=dict(seed=42,
                        repetition_penalty=1.9,
                        top_k=12,
                        top_p=0.4,
                        min_p=0.5,
                        frequency_penalty=1.8,
                        presence_penalty=2.0))

    # Same seed, but temperature 0.0 and no extra penalties.
    greedy_transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.0,
        extra_body=dict(seed=42))

    assert greedy_transcription.text != transcription.text


@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb, client):
    """Passing a ``prompt`` must not drop the opening of the transcript."""
    prompt = "This is a speech, recorded in a phonograph."
    # The opening phrase must appear both without and with the prompt, i.e.
    # the prompt must not cause part of the original speech to be omitted.
    prefix = "The first words I spoke in the original phonograph"
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0)
    out = json.loads(transcription)['text']
    assert prefix in out
    transcription_wprompt = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        prompt=prompt,
        temperature=0.0)
    out_prompt = json.loads(transcription_wprompt)['text']
    assert prefix in out_prompt