test_transcription_validation.py 3.77 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# imports for structured outputs tests
5
6
7
8
9
import json

import pytest

from ...utils import RemoteOpenAIServer
10
from .conftest import add_attention_backend
11

Patrick von Platen's avatar
Patrick von Platen committed
12
# Extra server CLI flags appended for models whose name starts with
# "mistralai": tokenizer, config and weights are all selected in "mistral"
# format.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
]

21
22

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", ["mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name, rocm_aiter_fa_attention):
    """Transcribe a sample clip and check the text and reported duration.

    Args:
        mary_had_lamb: Audio input passed as ``file`` to the transcription
            request (fixture defined outside this file).
        model_name: Model to serve and to address in the request.
        rocm_aiter_fa_attention: Forwarded unchanged to
            ``add_attention_backend`` (fixture defined outside this file).
    """
    server_args = ["--enforce-eager"]

    # Mistral-hosted checkpoints need the mistral tokenizer/config/load flags.
    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
        )
        # The response body is decoded as JSON carrying "text" and "usage".
        out = json.loads(transcription)
        out_text = out["text"]
        out_usage = out["usage"]
        assert "Mary had a little lamb," in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]
47
48


49
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
    """Ensure STT (transcribe) requests can pass LoRA through to generate.

    The server registers a LoRA module named ``speech`` that points at the
    base model itself, and the request addresses that LoRA name rather than
    the base model name.

    Args:
        mary_had_lamb: Audio input passed as ``file`` to the transcription
            request (fixture defined outside this file).
        rocm_aiter_fa_attention: Forwarded unchanged to
            ``add_attention_backend`` (fixture defined outside this file).
    """
    # ROCm SPECIFIC CONFIGURATION:
    # To ensure the test passes on ROCm, we modify the max model length to 512.
    # We DO NOT apply this to other platforms to maintain strict upstream parity.
    from vllm.platforms import current_platform

    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        "512" if current_platform.is_rocm() else "2048",
        "--max-num-seqs",
        "1",
    ]

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=lora_model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
        )
    # The response is already in hand, so it is parsed and checked after the
    # server context has exited.
    out = json.loads(transcription)
    out_text = out["text"]
    out_usage = out["usage"]
    assert "mary had a little lamb" in out_text
    assert out_usage["seconds"] == 16, out_usage["seconds"]


91
@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo, rocm_aiter_fa_attention):
    """Transcribe an Italian clip with Gemma 3n and check a known phrase.

    Args:
        foscolo: Audio input passed as ``file`` to the transcription request
            (fixture defined outside this file).
        rocm_aiter_fa_attention: Forwarded unchanged to
            ``add_attention_backend`` (fixture defined outside this file).
    """
    # Gemma accuracy on some of the audio samples we use is particularly bad,
    # hence we use a different one here. WER is evaluated separately.
    model_name = "google/gemma-3n-E2B-it"
    server_args = ["--enforce-eager"]

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # This server is given an explicit startup wait limit of 480 seconds.
    with RemoteOpenAIServer(
        model_name, server_args, max_wait_seconds=480
    ) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=foscolo,
            language="it",
            response_format="text",
            temperature=0.0,
        )
        out = json.loads(transcription)["text"]
        assert "da cui vergine nacque Venere" in out