test_translation_validation.py 7.5 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5

6
# imports for structured outputs tests
7
8
import json

9
import httpx
10
11
12
import librosa
import numpy as np
import pytest
13
import pytest_asyncio
14
15
16
17
import soundfile as sf

from ...utils import RemoteOpenAIServer

18
19
SERVER_ARGS = ["--enforce-eager"]

20

21
22
23
@pytest.fixture(
    scope="module",
    params=["openai/whisper-small", "google/gemma-3n-E2B-it"],
)
def server(request):
    """Launch a server for each parametrized model.

    Yields:
        A ``(remote_server, model_name)`` tuple; the server context is
        exited once the module-scoped fixture is finalized.
    """
    # The fixture param is the model name to serve.
    model_name = request.param
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        yield remote_server, model_name
28
29


30
@pytest_asyncio.fixture
async def client_and_model(server):
    """Yield ``(async_client, model_name)`` for the running ``server``.

    The async client is opened as a context manager, so it is closed when
    the fixture is torn down.
    """
    remote_server, model_name = server
    async with remote_server.get_async_client() as async_client:
        yield async_client, model_name
35
36
37
38
39
40


@pytest.mark.asyncio
async def test_non_asr_model(foscolo):
    """A text-to-text model must reject Translations API requests.

    Starts a dedicated server for a model without audio support and checks
    that the response carries a 400 error, no text, and the expected
    error message.
    """
    # text to text model
    model_name = "JackFram/llama-68m"
    with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
        # Open the client as a context manager (as the `client_and_model`
        # fixture does) so it is closed before the server shuts down.
        async with remote_server.get_async_client() as client:
            res = await client.audio.translations.create(
                model=model_name, file=foscolo, temperature=0.0
            )
            err = res.error
            # Separate asserts so a failure points at the exact condition.
            assert err["code"] == 400
            assert not res.text
            assert (
                err["message"]
                == "The model does not support Translations API"
            )
49
50


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb):
    """Ensure STT (translate) requests can pass LoRA through to generate.

    Starts a dedicated LoRA-enabled server, requests a translation through
    the registered LoRA module name, and checks the output for the
    expected phrase.
    """
    # NOTE - careful to call this test before the module scoped server
    # fixture, otherwise it'll OOMkill the CI
    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "1",
    ]

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Open the client as a context manager (as the `client_and_model`
        # fixture does) so it is closed before the server shuts down.
        async with remote_server.get_async_client() as client:
            translation = await client.audio.translations.create(
                # Address the LoRA module by its registered name.
                model=lora_model_name,
                file=mary_had_lamb,
                extra_body=dict(language="en", to_language="es"),
                response_format="text",
                temperature=0.0,
            )
    out = json.loads(translation)["text"].strip().lower()
    assert "mary tenía un pequeño cordero" in out


85
86
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo, client_and_model):
    """Translate the `foscolo` audio (`it` -> `en`) and check a known phrase."""
    client, model_name = client_and_model
    # TODO remove `language="it"` once language detection is implemented
    language_options = dict(language="it", to_language="en")
    response = await client.audio.translations.create(
        model=model_name,
        file=foscolo,
        response_format="text",
        extra_body=language_options,
        temperature=0.0,
    )
    # The response is parsed as a JSON document with a "text" field.
    translated_text = json.loads(response)["text"]
    normalized = translated_text.strip().lower()
    assert "greek sea" in normalized


@pytest.mark.asyncio
async def test_audio_prompt(foscolo, client_and_model):
    """Check that a conditioning prompt does not leak into the translation."""
    client, model_name = client_and_model
    # Condition whisper on starting text
    prompt = "Nor have I ever"
    translation = await client.audio.translations.create(
        model=model_name,
        file=foscolo,
        prompt=prompt,
        extra_body=dict(language="it", to_language="en"),
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(translation)["text"]
    # Neither the fixed reference phrase nor the prompt itself may appear.
    for unexpected in ("Nor will I ever touch the sacred", prompt):
        assert unexpected not in out


119
@pytest.mark.asyncio
async def test_streaming_response(foscolo, client_and_model, server):
    """Streamed translation output should closely match the non-streamed one.

    The same audio is translated once through the OpenAI client (no
    streaming) and once through a raw HTTPX streaming request; at least 90%
    of the streamed words must match the non-streamed words position by
    position.
    """
    client, model_name = client_and_model
    res_no_stream = await client.audio.translations.create(
        model=model_name,
        file=foscolo,
        response_format="json",
        extra_body=dict(language="it", to_language="en", seed=42),
        temperature=0.0,
    )

    # Stream via HTTPX since OpenAI translation client doesn't expose streaming
    remote_server, model_name = server
    url = remote_server.url_for("v1/audio/translations")
    headers = {"Authorization": f"Bearer {remote_server.DUMMY_API_KEY}"}
    data = {
        "model": model_name,
        "language": "it",
        "to_language": "en",
        "stream": True,
        "temperature": 0.0,
        "seed": 42,
    }
    # The first request consumed the file object; rewind before re-sending.
    foscolo.seek(0)
    streamed_parts = []
    async with httpx.AsyncClient() as http_client:
        files = {"file": foscolo}
        async with http_client.stream(
            "POST", url, headers=headers, data=data, files=files
        ) as response:
            async for line in response.aiter_lines():
                if not line:
                    continue
                # Drop the server-sent-events "data: " prefix when present.
                if line.startswith("data: "):
                    line = line[len("data: ") :]
                if line.strip() == "[DONE]":
                    break
                chunk = json.loads(line)
                text = chunk["choices"][0].get("delta", {}).get("content")
                if text:
                    streamed_parts.append(text)

    res_stream = "".join(streamed_parts).split()
    # Without this check an empty stream would satisfy the ratio assertion
    # below vacuously (0 >= 0).
    assert res_stream, "streaming request returned no translated text"
    # NOTE There's a small non-deterministic issue here, likely in the attn
    # computation, which will cause a few tokens to be different, while still
    # being very close semantically.
    expected_words = res_no_stream.text.split()
    matching = sum(x == y for x, y in zip(res_stream, expected_words))
    assert matching >= len(res_stream) * 0.9
168
169
170


@pytest.mark.asyncio
async def test_stream_options(foscolo, server):
    """Check usage reporting options on a streamed translation request.

    With `stream_include_usage` and `stream_continuous_usage_stats` set, the
    stream must contain a chunk without choices (the final usage chunk) and
    every chunk that has choices must carry a "usage" entry.
    """
    remote_server, model_name = server
    url = remote_server.url_for("v1/audio/translations")
    headers = {"Authorization": f"Bearer {remote_server.DUMMY_API_KEY}"}
    data = {
        "model": model_name,
        "language": "it",
        "to_language": "en",
        "stream": True,
        "stream_include_usage": True,
        "stream_continuous_usage_stats": True,
        "temperature": 0.0,
    }
    foscolo.seek(0)
    saw_final_usage = False
    usage_on_every_chunk = True
    sse_prefix = "data: "
    async with httpx.AsyncClient() as http_client:
        files = {"file": foscolo}
        async with http_client.stream(
            "POST", url, headers=headers, data=data, files=files
        ) as response:
            async for line in response.aiter_lines():
                if not line:
                    continue
                if line.startswith(sse_prefix):
                    line = line[len(sse_prefix) :]
                if line.strip() == "[DONE]":
                    break
                chunk = json.loads(line)
                if not chunk.get("choices", []):
                    # final usage sent
                    saw_final_usage = True
                    continue
                # Once a chunk lacks usage, the flag stays False.
                if "usage" not in chunk:
                    usage_on_every_chunk = False
    assert saw_final_usage and usage_on_every_chunk
207
208
209


@pytest.mark.asyncio
async def test_long_audio_request(foscolo, client_and_model):
    """Translate the audio repeated twice and expect the phrase twice.

    Skipped for `google/gemma-3n-E2B-it`.
    """
    client, model_name = client_and_model
    if model_name == "google/gemma-3n-E2B-it":
        pytest.skip("Gemma3n does not support long audio requests")

    foscolo.seek(0)
    samples, sample_rate = librosa.load(foscolo)
    # Repeated audio to buffer
    doubled = np.tile(samples, 2)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, doubled, sample_rate, format="WAV")
    # Rewind so the upload reads the WAV data from the start.
    wav_buffer.seek(0)

    response = await client.audio.translations.create(
        model=model_name,
        file=wav_buffer,
        extra_body=dict(language="it", to_language="en"),
        response_format="text",
        temperature=0.0,
    )
    translated_text = json.loads(response)["text"]
    out = translated_text.strip().lower()
    assert out.count("greek sea") == 2