# File: test_translation_validation.py (9.44 KB)
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5

6
# imports for structured outputs tests
7
8
import json

9
import httpx
10
11
12
import librosa
import numpy as np
import pytest
13
import pytest_asyncio
14
15
16
import soundfile as sf

from ...utils import RemoteOpenAIServer
17
from .conftest import add_attention_backend
18

19
20
# Baseline server CLI flags; `_get_server_args` copies this list before adding
# the attention-backend options, so the module-level value stays untouched.
SERVER_ARGS = ["--enforce-eager"]

21

22
23
24
25
26
27
28
def _get_server_args(attention_config):
    """Return a fresh server argument list with the attention backend applied.

    A copy of ``SERVER_ARGS`` is handed to ``add_attention_backend`` together
    with ``attention_config``; the copy (not the module-level list) is what
    gets returned to the caller.
    """
    server_args = list(SERVER_ARGS)
    add_attention_backend(server_args, attention_config)
    return server_args


29
30
31
@pytest.fixture(
    scope="module",
    params=["openai/whisper-small", "google/gemma-3n-E2B-it"],
)
def server(request, rocm_aiter_fa_attention):
    """Start one ``RemoteOpenAIServer`` per parametrized model name.

    The fixture is module scoped and parametrized over the model name, so
    every test that depends on it runs once for each model.

    Yields:
        A ``(remote_server, model_name)`` tuple; the server context is exited
        when the fixture is torn down.
    """
    model_name = request.param
    server_args = _get_server_args(rocm_aiter_fa_attention)
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        yield remote_server, model_name
38
39


40
@pytest_asyncio.fixture
async def client_and_model(server):
    """Provide an async client bound to the running server.

    Args:
        server: The ``(remote_server, model_name)`` tuple yielded by the
            ``server`` fixture.

    Yields:
        An ``(async_client, model_name)`` tuple; the client context is exited
        once the requesting test finishes.
    """
    # Use a distinct local name instead of rebinding the ``server`` parameter.
    remote_server, model_name = server
    async with remote_server.get_async_client() as async_client:
        yield async_client, model_name
45
46
47


@pytest.mark.asyncio
async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
    """A text-to-text model must reject Translations API requests with a 400."""
    # text to text model
    model_name = "JackFram/llama-68m"
    server_args = _get_server_args(rocm_aiter_fa_attention)
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Enter the client as a context manager (as the `client_and_model`
        # fixture does) so it is closed before the server is torn down.
        async with remote_server.get_async_client() as client:
            res = await client.audio.translations.create(
                model=model_name, file=foscolo, temperature=0.0
            )
        err = res.error
        # Separate asserts so a failure pinpoints which expectation broke.
        assert err["code"] == 400
        assert not res.text
        assert err["message"] == "The model does not support Translations API"
61
62


63
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
    """Ensure STT (translate) requests can pass LoRA through to generate."""
    # ROCm SPECIFIC CONFIGURATION:
    # To ensure the test passes on ROCm, we modify the max model length to 512.
    # We DO NOT apply this to other platforms to maintain strict upstream parity.
    from vllm.platforms import current_platform

    # NOTE - careful to call this test before the module scoped server
    # fixture, otherwise it'll OOMkill the CI
    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    max_model_len = "512" if current_platform.is_rocm() else "2048"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        max_model_len,
        "--max-num-seqs",
        "1",
    ]
    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        # Enter the client as a context manager (as the `client_and_model`
        # fixture does) so it is closed before the server is torn down.
        async with remote_server.get_async_client() as client:
            # The request targets the LoRA module name, not the base model.
            translation = await client.audio.translations.create(
                model=lora_model_name,
                file=mary_had_lamb,
                extra_body={"language": "en", "to_language": "es"},
                response_format="text",
                temperature=0.0,
            )
    out = json.loads(translation)["text"].strip().lower()
    assert "pequeño" in out.split(" ")
102
103


104
105
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo, client_and_model):
    """Translate the Italian sample to English and check for a known phrase."""
    client, model_name = client_and_model
    translation = await client.audio.translations.create(
        model=model_name,
        file=foscolo,
        response_format="text",
        # TODO remove `language="it"` once language detection is implemented
        extra_body={"language": "it", "to_language": "en"},
        temperature=0.0,
    )
    # The "text" response body is parsed as JSON carrying a "text" field.
    out = json.loads(translation)["text"].strip().lower()
    assert "greek sea" in out


@pytest.mark.asyncio
async def test_audio_prompt(foscolo, client_and_model):
    """Check that a conditioning prompt is not echoed back in the translation."""
    client, model_name = client_and_model
    # Condition whisper on starting text
    prompt = "Nor have I ever"
    translation = await client.audio.translations.create(
        model=model_name,
        file=foscolo,
        prompt=prompt,
        extra_body={"language": "it", "to_language": "en"},
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(translation)["text"]
    assert "Nor will I ever touch the sacred" not in out
    assert prompt not in out


138
@pytest.mark.asyncio
async def test_streaming_response(foscolo, client_and_model, server):
    """Compare streamed translation text against the non-streamed response.

    The same audio is translated twice with identical sampling settings: once
    through the OpenAI client without streaming, and once through a raw HTTPX
    streaming request. At least 90% of the streamed words must match the
    non-streamed words position by position.
    """
    client, model_name = client_and_model
    res_no_stream = await client.audio.translations.create(
        model=model_name,
        file=foscolo,
        response_format="json",
        extra_body={"language": "it", "to_language": "en", "seed": 42},
        temperature=0.0,
    )

    # Stream via HTTPX since OpenAI translation client doesn't expose streaming
    remote_server, model_name = server
    url = remote_server.url_for("v1/audio/translations")
    headers = {"Authorization": f"Bearer {remote_server.DUMMY_API_KEY}"}
    data = {
        "model": model_name,
        "language": "it",
        "to_language": "en",
        "stream": True,
        "temperature": 0.0,
        "seed": 42,
    }
    # Rewind the audio file, which was already handed to the request above.
    foscolo.seek(0)
    pieces = []
    async with httpx.AsyncClient() as http_client:
        files = {"file": foscolo}
        async with http_client.stream(
            "POST", url, headers=headers, data=data, files=files
        ) as response:
            async for line in response.aiter_lines():
                if not line:
                    continue
                # Strip the server-sent-event prefix when present.
                if line.startswith("data: "):
                    line = line[len("data: ") :]
                if line.strip() == "[DONE]":
                    break
                chunk = json.loads(line)
                text = chunk["choices"][0].get("delta", {}).get("content")
                pieces.append(text or "")

    res_stream = "".join(pieces).split()
    expected = res_no_stream.text.split()
    # Without this guard an empty stream would satisfy the ratio check below
    # trivially (0 >= 0).
    assert res_stream, "streaming request produced no text"
    # NOTE There's a small non-deterministic issue here, likely in the attn
    # computation, which will cause a few tokens to be different, while still
    # being very close semantically.
    matching = sum(x == y for x, y in zip(res_stream, expected))
    assert matching >= len(res_stream) * 0.9
187
188
189


@pytest.mark.asyncio
async def test_stream_options(foscolo, server):
    """Check usage reporting options on a streamed translation request.

    With ``stream_include_usage`` and ``stream_continuous_usage_stats`` set,
    the test expects a final chunk with no choices, and a ``usage`` entry on
    every chunk that does carry choices.
    """
    remote_server, model_name = server
    url = remote_server.url_for("v1/audio/translations")
    headers = {"Authorization": f"Bearer {remote_server.DUMMY_API_KEY}"}
    data = {
        "model": model_name,
        "language": "it",
        "to_language": "en",
        "stream": True,
        "stream_include_usage": True,
        "stream_continuous_usage_stats": True,
        "temperature": 0.0,
    }
    foscolo.seek(0)
    saw_final_usage = False
    usage_on_every_chunk = True
    async with httpx.AsyncClient() as http_client:
        files = {"file": foscolo}
        async with http_client.stream(
            "POST", url, headers=headers, data=data, files=files
        ) as response:
            async for line in response.aiter_lines():
                if not line:
                    continue
                # Strip the server-sent-event prefix when present.
                if line.startswith("data: "):
                    line = line[len("data: ") :]
                if line.strip() == "[DONE]":
                    break
                chunk = json.loads(line)
                if not chunk.get("choices", []):
                    # final usage sent
                    saw_final_usage = True
                elif "usage" not in chunk:
                    usage_on_every_chunk = False
    # Separate asserts so a failure shows which expectation was not met.
    assert saw_final_usage, "no final chunk without choices was received"
    assert usage_on_every_chunk, "a chunk with choices was missing 'usage'"
226
227
228


@pytest.mark.asyncio
async def test_long_audio_request(foscolo, client_and_model):
    """Translate the sample audio played twice and expect the phrase twice."""
    client, model_name = client_and_model
    if model_name == "google/gemma-3n-E2B-it":
        pytest.skip("Gemma3n does not support long audio requests")

    foscolo.seek(0)
    audio, sr = librosa.load(foscolo)
    # Concatenate the clip with itself to build the longer input.
    repeated_audio = np.tile(audio, 2)
    # Repeated audio to buffer
    buffer = io.BytesIO()
    sf.write(buffer, repeated_audio, sr, format="WAV")
    buffer.seek(0)
    translation = await client.audio.translations.create(
        model=model_name,
        file=buffer,
        extra_body={"language": "it", "to_language": "en"},
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(translation)["text"].strip().lower()
    assert out.count("greek sea") == 2
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269


@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
    """Check that ``max_completion_tokens`` bounds the generated output.

    First request: a limit of 1 must yield exactly one token when the output
    text is re-tokenized. Second request: a limit far above the model length
    must still yield fewer than 450 tokens.
    """
    client, model_name = client_and_model
    translation = await client.audio.translations.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={"max_completion_tokens": 1},
    )
    out_text = json.loads(translation)["text"]
    print(out_text)
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_name)
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) == 1

    # max_completion_tokens > max_model_len
    # max_model_len=32768 for Gemma-3n-E2B-it
    # NOTE(review): this second request goes to the *transcriptions* endpoint
    # while the first uses *translations* - confirm whether that is intended.
    transcription = await client.audio.transcriptions.create(
        model=model_name,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
        extra_body={
            "max_completion_tokens": int(1e6),
            "repetition_penalty": 1.3,
        },
    )
    out_text = json.loads(transcription)["text"]
    print(out_text)
    out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
    assert len(out_tokens) < 450  # ~Whisper max output len