# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# imports for transcription API validation tests
import io
import json

import librosa
import numpy as np
import openai
import pytest
import pytest_asyncio
import soundfile as sf

from vllm.assets.audio import AudioAsset

from ...utils import RemoteOpenAIServer

# Speech-to-text model served by the module-scoped `server` fixture.
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Default CLI flags for the servers launched by these tests.
SERVER_ARGS = ["--enforce-eager"]

# Extra CLI flags for checkpoints in Mistral's native format (tokenizer,
# config and weights), e.g. the Voxtral model exercised in test_basic_audio.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode", "mistral", "--config_format", "mistral",
    "--load_format", "mistral"
]


@pytest.fixture
def mary_had_lamb():
    """Yield an open binary handle to the bundled 'mary_had_lamb' clip.

    The handle is opened with ``with``, so it is closed when the fixture is
    torn down after the requesting test.
    """
    clip_path = str(AudioAsset('mary_had_lamb').get_local_path())
    with open(clip_path, "rb") as clip_file:
        yield clip_file


@pytest.fixture
def winning_call():
    """Yield an open binary handle to the bundled 'winning_call' clip.

    The handle is opened with ``with``, so it is closed when the fixture is
    torn down after the requesting test.
    """
    clip_path = str(AudioAsset('winning_call').get_local_path())
    with open(clip_path, "rb") as clip_file:
        yield clip_file


@pytest.fixture(scope="module")
def server():
    """Start one server for ``MODEL_NAME`` shared by every test in the module.

    The server is launched with ``SERVER_ARGS``; its context manager is
    exited at module teardown.
    """
    launcher = RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS)
    with launcher as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Provide an async OpenAI client connected to the shared ``server``.

    The client is entered as an async context manager and exited when the
    requesting test finishes.
    """
    api_client = server.get_async_client()
    async with api_client as connected_client:
        yield connected_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
    """Transcribe the 'Mary had a little lamb' clip with each listed model.

    Every parametrized model gets its own dedicated server. Mistral-format
    checkpoints additionally need ``MISTRAL_FORMAT_ARGS``.
    """
    # Copy so the module-level SERVER_ARGS list is never mutated by `+=`.
    server_args = list(SERVER_ARGS)
    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Enter the client as an async context manager (as the `client`
        # fixture does) so its connections are released after the request.
        async with remote_server.get_async_client() as client:
            transcription = await client.audio.transcriptions.create(
                model=model_name,
                file=mary_had_lamb,
                language="en",
                response_format="text",
                temperature=0.0)
        # The server's "text" response is parsed here as JSON with a
        # 'text' key.
        out = json.loads(transcription)['text']
        assert "Mary had a little lamb," in out


@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
    """A text-only model must reject Transcriptions API requests with a 400."""
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # Enter the client as an async context manager (as the `client`
        # fixture does) so its connections are released after the request.
        async with remote_server.get_async_client() as client:
            res = await client.audio.transcriptions.create(model=model_name,
                                                           file=winning_call,
                                                           language="en",
                                                           temperature=0.0)
        err = res.error
        assert err["code"] == 400 and not res.text
        assert err[
            "message"] == "The model does not support Transcriptions API"

@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb, client):
    """Invalid request parameters must surface as ``openai.BadRequestError``."""
    # invalid language: "hh" is not a language code the server accepts
    with pytest.raises(openai.BadRequestError):
        await client.audio.transcriptions.create(model=MODEL_NAME,
                                                 file=mary_had_lamb,
                                                 language="hh",
                                                 temperature=0.0)

@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, client):
    """Transcribe the 'mary_had_lamb' clip tiled several times back to back.

    Each repetition must appear exactly once in the transcription. The
    concatenated audio is presumably long enough to exercise the server's
    audio-splitting path (see the padding comment below).
    """
    repeats = 10
    mary_had_lamb.seek(0)
    audio, sr = librosa.load(mary_had_lamb)
    # Add small silence after each audio for repeatability in the split process
    # (1600 trailing zero samples per copy).
    audio = np.pad(audio, (0, 1600))
    repeated_audio = np.tile(audio, repeats)
    # Repeated audio to an in-memory WAV buffer
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sr, format='WAV')
    buffer.seek(0)
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=buffer,
        language="en",
        response_format="text",
        temperature=0.0)
    out = json.loads(transcription)['text']
    counts = out.count("Mary had a little lamb")
    assert counts == repeats, counts


@pytest.mark.asyncio
async def test_completion_endpoints(client):
    """The speech-to-text model must reject text-generation endpoints (400)."""
    # MODEL_NAME is the transcription model, so neither the Chat Completions
    # nor the Completions endpoint should be available for it.
    res = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "system",
            "content": "You are a helpful assistant."
        }])
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Chat Completions API"

    res = await client.completions.create(model=MODEL_NAME, prompt="Hello")
    err = res.error
    assert err["code"] == 400
    assert err["message"] == "The model does not support Completions API"


@pytest.mark.asyncio
async def test_streaming_response(winning_call, client):
    """Streamed transcription deltas must join to the non-streamed text."""
    res_no_stream = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        response_format="json",
        language="en",
        temperature=0.0)
    res = await client.audio.transcriptions.create(model=MODEL_NAME,
                                                   file=winning_call,
                                                   language="en",
                                                   temperature=0.0,
                                                   stream=True,
                                                   timeout=30)
    # Reconstruct from chunks and validate. Each chunk's choice is read with
    # dict indexing (['delta']['content']).
    pieces = []
    async for chunk in res:
        pieces.append(chunk.choices[0]['delta']['content'])
    transcription = "".join(pieces)

    assert transcription == res_no_stream.text


@pytest.mark.asyncio
async def test_stream_options(winning_call, client):
    """Check streaming with ``stream_include_usage`` and continuous usage.

    A final chunk with no choices (the usage summary) must be received, and
    every chunk that carries choices must expose a ``usage`` attribute.
    """
    res = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        extra_body=dict(stream_include_usage=True,
                        stream_continuous_usage_stats=True),
        timeout=30)
    final = False
    continuous = True
    async for chunk in res:
        if not chunk.choices:
            # final usage sent
            final = True
        else:
            # NOTE(review): hasattr only checks the attribute exists; it does
            # not verify that the usage value is populated.
            continuous = continuous and hasattr(chunk, 'usage')
    assert final and continuous


@pytest.mark.asyncio
async def test_sampling_params(mary_had_lamb, client):
    """Check that extreme sampling parameters change the transcription.

    A high-temperature request with aggressive penalties and top-k/top-p/min-p
    truncation must produce different text than a greedy (temperature 0)
    request using the same seed.
    """
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.8,
        extra_body=dict(seed=42,
                        repetition_penalty=1.9,
                        top_k=12,
                        top_p=0.4,
                        min_p=0.5,
                        frequency_penalty=1.8,
                        presence_penalty=2.0))

    greedy_transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        temperature=0.0,
        extra_body=dict(seed=42))

    assert greedy_transcription.text != transcription.text


@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb, client):
    """A text prompt must not drop the opening words of the transcription.

    The clip is transcribed without and then with ``prompt``; both outputs
    must contain the recording's opening phrase.
    """
    prompt = "This is a speech, recorded in a phonograph."
    # Supplying a prompt must not cause any part of the original speech to
    # be omitted from the transcription.
    prefix = "The first words I spoke in the original phonograph"
    transcription = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0)
    out = json.loads(transcription)['text']
    assert prefix in out
    transcription_wprompt = await client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        prompt=prompt,
        temperature=0.0)
    out_prompt = json.loads(transcription_wprompt)['text']
    assert prefix in out_prompt