test_transcription_validation.py 7.89 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11

# imports for transcription API tests
import io
import json

import librosa
import numpy as np
import openai
import pytest
12
import pytest_asyncio
13
14
15
16
17
18
import soundfile as sf

from vllm.assets.audio import AudioAsset

from ...utils import RemoteOpenAIServer

19
20
# Model served by the module-scoped `server` fixture and used as the default
# `model=` argument by the tests that rely on the shared `client` fixture.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Command-line arguments handed to RemoteOpenAIServer when it is launched.
SERVER_ARGS = ["--enforce-eager"]
Patrick von Platen's avatar
Patrick von Platen committed
21
22
23
24
25
# Extra server flags appended when the model under test is a "mistralai/..."
# checkpoint. Each line below is one "--flag value" pair.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode", "mistral",
    "--config_format", "mistral",
    "--load_format", "mistral",
]

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

@pytest.fixture
def mary_had_lamb():
    """Yield the ``mary_had_lamb`` audio asset as an open binary file handle.

    The handle is closed when the ``with`` block exits, i.e. after the
    requesting test has finished with it.
    """
    audio_path = str(AudioAsset('mary_had_lamb').get_local_path())
    with open(audio_path, "rb") as audio_file:
        yield audio_file


@pytest.fixture
def winning_call():
    """Yield the ``winning_call`` audio asset as an open binary file handle.

    The handle is closed when the ``with`` block exits, i.e. after the
    requesting test has finished with it.
    """
    audio_path = str(AudioAsset('winning_call').get_local_path())
    with open(audio_path, "rb") as audio_file:
        yield audio_file


41
42
43
44
45
46
47
48
49
50
51
52
@pytest.fixture(scope="module")
def server():
    """Launch a RemoteOpenAIServer for MODEL_NAME once per test module.

    The server context is exited after the last test of the module that
    uses this fixture.
    """
    launcher = RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS)
    with launcher as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the shared ``server`` fixture.

    The client is used as an async context manager, so its ``async with``
    block is exited once the requesting test is done.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client


53
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
    """Transcribe the ``mary_had_lamb`` clip with each parametrized model.

    A dedicated RemoteOpenAIServer is started for every ``model_name``.
    Models whose name starts with ``mistralai`` additionally receive
    MISTRAL_FORMAT_ARGS. The test checks the transcribed text and the
    reported audio duration in the ``usage`` section of the response.
    """
    server_args = ["--enforce-eager"]

    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Enter the client as an async context manager, exactly like the
        # `client` fixture does, so it is released before the server stops.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0)

        # The "text" response is decoded as JSON here to reach both the
        # transcript and the usage statistics.
        out = json.loads(transcription)
        out_text = out['text']
        out_usage = out['usage']
        assert "Mary had a little lamb," in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]
77
78
79


@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
    """A text-to-text model must answer transcription requests with a 400.

    The error is read from the ``error`` field of the returned object
    rather than from a raised exception.
    """
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # Enter the client as an async context manager, exactly like the
        # `client` fixture does, so it is released before the server stops.
        async with remote_server.get_async_client() as client:
            res = await client.audio.transcriptions.create(
                model=model_name,
                file=winning_call,
                language="en",
                temperature=0.0)

        err = res.error
        # Separate asserts so a failure pinpoints the offending condition.
        assert err["code"] == 400
        assert not res.text
        assert err[
            "message"] == "The model does not support Transcriptions API"
93

94
95

@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb, client):
    """An unknown language code must be rejected with a BadRequestError."""
    # invalid language
    bad_request = dict(model=MODEL_NAME,
                       file=mary_had_lamb,
                       language="hh",
                       temperature=0.0)
    with pytest.raises(openai.BadRequestError):
        await client.audio.transcriptions.create(**bad_request)

104

105
106
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
    """Transcribe the clip repeated ten times and count the repetitions.

    The repeated waveform is written to an in-memory WAV file and sent as
    a single request. The transcript must contain the opening line once
    per repetition and the reported duration must match.
    """
    # Read the asset from its very beginning.
    mary_had_lamb.seek(0)
    waveform, sample_rate = librosa.load(mary_had_lamb)
    # Add small silence after each audio for repeatability in the split process
    padded_waveform = np.pad(waveform, (0, 1600))
    long_waveform = np.tile(padded_waveform, 10)

    # Repeated audio to buffer
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, long_waveform, sample_rate, format='WAV')
    wav_buffer.seek(0)

    response = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=wav_buffer,
        language="en",
        response_format="text",
        temperature=0.0)

    payload = json.loads(response)
    transcript = payload['text']
    usage = payload['usage']
    occurrences = transcript.count("Mary had a little lamb")
    assert occurrences == 10, occurrences
    assert usage["seconds"] == 161, usage["seconds"]
128
129
130


@pytest.mark.asyncio
async def test_completion_endpoints(client):
    """The served model must reject chat and plain completion requests.

    Both endpoints are expected to return an object whose ``error`` field
    carries code 400 and an endpoint-specific message.
    """
    # MODEL_NAME is the model behind the shared server; the text endpoints
    # below are queried against it on purpose.
    chat_messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }]
    chat_res = await client.chat.completions.create(model=MODEL_NAME,
                                                    messages=chat_messages)
    chat_err = chat_res.error
    assert chat_err["code"] == 400
    assert chat_err[
        "message"] == "The model does not support Chat Completions API"

    completion_res = await client.completions.create(model=MODEL_NAME,
                                                     prompt="Hello")
    completion_err = completion_res.error
    assert completion_err["code"] == 400
    assert completion_err[
        "message"] == "The model does not support Completions API"
147
148
149


@pytest.mark.asyncio
async def test_streaming_response(winning_call, client):
    """Streamed chunks must concatenate to the non-streamed transcript."""
    reference = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        response_format="json",
        language="en",
        temperature=0.0)
    stream = await client.audio.transcriptions.create(model=MODEL_NAME,
                                                      file=winning_call,
                                                      language="en",
                                                      temperature=0.0,
                                                      stream=True,
                                                      timeout=30)
    # Reconstruct from chunks and validate
    pieces = [
        chunk.choices[0]['delta']['content'] async for chunk in stream
    ]
    streamed_text = "".join(pieces)

    assert streamed_text == reference.text
170
171
172


@pytest.mark.asyncio
async def test_stream_options(winning_call, client):
    """Check usage reporting when both streaming usage options are enabled.

    The request sets ``stream_include_usage`` and
    ``stream_continuous_usage_stats`` through ``extra_body``. The test
    passes when a chunk with empty ``choices`` is received (treated as the
    final usage message) and every chunk that does carry choices exposes a
    ``usage`` attribute.
    """
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        extra_body=dict(stream_include_usage=True,
                        stream_continuous_usage_stats=True),
        timeout=30)
    final = False
    continuous = True
    async for chunk in res:
        if not chunk.choices:
            # final usage sent
            final = True
        else:
            continuous = continuous and hasattr(chunk, 'usage')
    # Separate asserts so a failure says which expectation was violated.
    assert final, "no final chunk with empty `choices` was received"
    assert continuous, "a chunk with choices had no `usage` attribute"
192
193
194


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb, client):
    """
    Compare sampling with params and greedy sampling to assert results
    are different when extreme sampling parameters values are picked.
    """
    extreme_sampling = {
        "seed": 42,
        "repetition_penalty": 1.9,
        "top_k": 12,
        "top_p": 0.4,
        "min_p": 0.5,
        "frequency_penalty": 1.8,
        "presence_penalty": 2.0,
    }
    sampled = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.8,
        extra_body=extreme_sampling)

    # Same seed, but temperature 0.0 and no other sampling overrides.
    greedy = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.0,
        extra_body={"seed": 42})

    assert greedy.text != sampled.text
221
222
223


@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb, client):
    """Supplying a prompt must not drop the start of the transcript.

    The clip is transcribed once without and once with ``prompt``; the
    expected opening words have to be present in both results.
    """
    prompt = "This is a speech, recorded in a phonograph."
    # Prompts should not omit the part of original prompt while transcribing.
    prefix = "The first words I spoke in the original phonograph"

    # Arguments common to both requests.
    shared_kwargs = dict(model=MODEL_NAME,
                         file=mary_had_lamb,
                         language="en",
                         response_format="text",
                         temperature=0.0)

    plain = await client.audio.transcriptions.create(**shared_kwargs)
    plain_text = json.loads(plain)['text']
    assert prefix in plain_text

    prompted = await client.audio.transcriptions.create(prompt=prompt,
                                                        **shared_kwargs)
    prompted_text = json.loads(prompted)['text']
    assert prefix in prompted_text