test_transcription_validation.py 4.83 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# imports for structured outputs tests
5
6
7
8
import json

import pytest

9
from ...utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
10
from .conftest import add_attention_backend
11

Patrick von Platen's avatar
Patrick von Platen committed
12
# Server CLI flags that select the "mistral" tokenizer mode, config format and
# load format. Appended to the server arguments for ``mistralai/*`` models.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]

21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
async def transcribe_and_check(
    client,
    model_name: str,
    file,
    *,
    language: str,
    expected_text: str,
    expected_seconds: int | None = None,
    case_sensitive: bool = False,
):
    """Run a transcription request and assert the output contains
    *expected_text* and optionally that usage reports *expected_seconds*.

    Provides detailed failure messages with the actual transcription output.
    """
    transcription = await client.audio.transcriptions.create(
        model=model_name,
        file=file,
        language=language,
        response_format="text",
        temperature=0.0,
    )
    out = json.loads(transcription)
    out_text = out["text"]
    out_usage = out["usage"]

    if case_sensitive:
        assert expected_text in out_text, (
            f"Expected {expected_text!r} in transcription output, got: {out_text!r}"
        )
    else:
        assert expected_text.lower() in out_text.lower(), (
            f"Expected {expected_text!r} (case-insensitive) in transcription "
            f"output, got: {out_text!r}"
        )

    if expected_seconds is not None:
        assert out_usage["seconds"] == expected_seconds, (
            f"Expected {expected_seconds}s of audio, "
            f"got {out_usage['seconds']}s. Full usage: {out_usage!r}"
        )


65
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["mistralai/Voxtral-Mini-3B-2507", "Qwen/Qwen3-ASR-0.6B"]
)
async def test_basic_audio(mary_had_lamb, model_name, rocm_aiter_fa_attention):
    """Transcribe the ``mary_had_lamb`` clip with each parametrized model.

    Asserts (case-insensitively) that the transcript contains
    "Mary had a little lamb" and that the usage block reports 16 seconds.
    """
    server_args = ["--enforce-eager", *ROCM_EXTRA_ARGS]

    # Models from the mistralai org are served with the Mistral-format flags.
    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(
        model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
    ) as remote_server:
        client = remote_server.get_async_client()
        await transcribe_and_check(
            client,
            model_name,
            mary_had_lamb,
            language="en",
            expected_text="Mary had a little lamb",
            expected_seconds=16,
        )


92
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
    """Ensure STT (transcribe) requests can pass LoRA through to generate.

    The server registers the base checkpoint itself as a LoRA module named
    "speech". The transcription request then targets that adapter name
    rather than the base model name.
    """
    from vllm.platforms import current_platform

    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        # ROCm SPECIFIC CONFIGURATION:
        # To ensure the test passes on ROCm, we modify the max model length
        # to 512. We DO NOT apply this to other platforms to maintain strict
        # upstream parity.
        "512" if current_platform.is_rocm() else "2048",
        "--max-num-seqs",
        "1",
    ]

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(
        model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
    ) as remote_server:
        client = remote_server.get_async_client()
        # Address the request to the LoRA adapter, not the base model.
        await transcribe_and_check(
            client,
            lora_model_name,
            mary_had_lamb,
            language="en",
            expected_text="mary had a little lamb",
            expected_seconds=16,
        )


132
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name", ["google/gemma-3n-E2B-it", "Qwen/Qwen3-ASR-0.6B"]
)
async def test_basic_audio_foscolo(foscolo, rocm_aiter_fa_attention, model_name):
    """Transcribe the Italian ``foscolo`` clip and check for a known phrase.

    Asserts (case-insensitively) that the transcript contains
    "ove il mio corpo fanciulletto giacque". Audio duration is not checked.
    """
    # Gemma accuracy on some of the audio samples we use is particularly bad,
    # hence we use a different one here. WER is evaluated separately.
    server_args = ["--enforce-eager", *ROCM_EXTRA_ARGS]

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # NOTE(review): the 480s startup timeout is longer than the other tests
    # here use; presumably these models are slower to load. Confirm.
    with RemoteOpenAIServer(
        model_name,
        server_args,
        max_wait_seconds=480,
        env_dict=ROCM_ENV_OVERRIDES,
    ) as remote_server:
        client = remote_server.get_async_client()
        await transcribe_and_check(
            client,
            model_name,
            foscolo,
            language="it",
            expected_text="ove il mio corpo fanciulletto giacque",
        )