# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
5
from typing import Any
6
7
8

import numpy as np
import pytest
9
import pytest_asyncio
10
from transformers import AutoTokenizer
11

12
from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
13
from ....utils import RemoteOpenAIServer
14
from ...registry import HF_EXAMPLE_MODELS
15

16
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

# Prompt text keyed by audio asset name; `AUDIO_ASSETS.prompts` (defined in
# conftest) turns this mapping into the per-asset prompt collection.
AUDIO_PROMPTS = AUDIO_ASSETS.prompts({
    "mary_had_lamb":
    "Transcribe this into English.",
    "winning_call":
    "What is happening in this audio clip?",
})

# Question asked when every audio asset is placed in a single prompt.
MULTI_AUDIO_PROMPT = "Describe each of the audios above."

# One audio clip as `(audio, sample_rate)`, matching
# `AudioAsset.audio_and_sample_rate`.
AudioTuple = tuple[np.ndarray, int]

# Audio placeholder inserted into prompts for vLLM and for HF respectively.
VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>"

# Engine kwargs shared by the offline and online tests to run the model with
# chunked prefill enabled.
CHUNKED_PREFILL_KWARGS = {
    "enable_chunked_prefill": True,
    "max_num_seqs": 2,
    # Use a very small limit to exercise chunked prefill.
    "max_num_batched_tokens": 16
}


def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
    """Translate a mapping of engine kwargs into ``--flag`` style CLI args.

    Underscores in each key become dashes. A boolean ``True`` yields a bare
    flag, a boolean ``False`` is omitted entirely, and any other value is
    rendered as ``--flag=value``. Insertion order of the mapping is kept.
    """
    cli_args: list[str] = []
    for name, setting in params_kwargs.items():
        flag = "--" + name.replace("_", "-")
        if not isinstance(setting, bool):
            cli_args.append(f"{flag}={setting}")
        elif setting:
            cli_args.append(flag)
    return cli_args


@pytest.fixture(params=[
    pytest.param({}, marks=pytest.mark.cpu_model),
    pytest.param(CHUNKED_PREFILL_KWARGS),
])
def server(request, audio_assets: AudioTestAssets):
    """Launch an OpenAI-compatible vLLM server for ``MODEL_NAME``.

    Parametrized twice: once with default engine settings (also marked
    ``cpu_model``) and once with ``CHUNKED_PREFILL_KWARGS`` appended as CLI
    flags. The per-prompt audio limit equals the number of available audio
    assets, so one request can carry all of them.

    Yields:
        The running ``RemoteOpenAIServer`` instance.
    """
    args = [
        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
        "--limit-mm-per-prompt",
        json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
    ] + params_kwargs_to_cli_args(request.param)

    # Audio is fetched by URL in the online test, so the fetch timeout is set
    # explicitly (value assumed to be in seconds -- confirm).
    with RemoteOpenAIServer(MODEL_NAME,
                            args,
                            env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
                                      "30"}) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client connected to the running ``server`` fixture.

    The client comes from ``server.get_async_client()``. Its ``async with``
    context is exited when the fixture is torn down after the test.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client


def _get_prompt(audio_count: int, question: str, placeholder: str) -> str:
    """Render a chat-formatted prompt with ``audio_count`` audio placeholders.

    Each placeholder is emitted on its own line ahead of ``question``. The
    result is passed through the model tokenizer's chat template as a single
    user turn, with the generation prompt appended.

    Args:
        audio_count: Number of audio placeholders to prepend.
        question: Text of the user question.
        placeholder: Audio placeholder string (e.g. ``VLLM_PLACEHOLDER``).

    Returns:
        The untokenized prompt string.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    audio_tags = f"{placeholder}\n" * audio_count

    return tokenizer.apply_chat_template(
        [{
            'role': 'user',
            'content': f"{audio_tags}{question}"
        }],
        tokenize=False,
        add_generation_prompt=True,
    )


def run_multi_audio_test(
    vllm_runner: type[VllmRunner],
    prompts_and_audios: list[tuple[str, list[AudioTuple]]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    **kwargs,
):
    """Run greedy generation on multi-audio prompts and check for output.

    The HF example-model registry is consulted first. Both registry checks
    use ``on_fail="skip"``, so the test is skipped rather than failed when
    the model is unavailable online or ``transformers`` is unsupported.

    Args:
        vllm_runner: Runner factory used as a context manager.
        prompts_and_audios: ``(prompt, audios)`` pairs, one per request.
        model: Model name to load.
        dtype: Model dtype.
        max_tokens: Maximum number of tokens to generate per request.
        num_logprobs: Number of logprobs requested per token.
        **kwargs: Extra engine kwargs forwarded to ``vllm_runner``.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # Allow as many audios per prompt as the largest request needs.
    max_audios_per_prompt = max(
        len(audios) for _, audios in prompts_and_audios)

    with vllm_runner(model,
                     dtype=dtype,
                     enforce_eager=True,
                     limit_mm_per_prompt={"audio": max_audios_per_prompt},
                     **kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            [prompt for prompt, _ in prompts_and_audios],
            max_tokens,
            num_logprobs=num_logprobs,
            audios=[audios for _, audios in prompts_and_audios])

    # The HuggingFace model doesn't support multiple audios yet, so
    # just assert that some tokens were generated. The first element of each
    # output is presumably the generated token ids -- confirm in VllmRunner.
    assert all(tokens for tokens, *_ in vllm_outputs)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [
    pytest.param({}, marks=pytest.mark.cpu_model),
    pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models_with_multiple_audios(vllm_runner,
                                     audio_assets: AudioTestAssets, dtype: str,
                                     max_tokens: int, num_logprobs: int,
                                     vllm_kwargs: dict) -> None:
    """Offline generation with every audio asset placed in a single prompt.

    The test runs twice. The first run uses default engine settings. The
    second uses ``CHUNKED_PREFILL_KWARGS`` (chunked prefill with a very small
    batched-token budget).
    """
    vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT,
                              VLLM_PLACEHOLDER)
    run_multi_audio_test(
        vllm_runner,
        [(vllm_prompt, [audio.audio_and_sample_rate
                        for audio in audio_assets])],
        MODEL_NAME,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        **vllm_kwargs,
    )


@pytest.mark.asyncio
async def test_online_serving(client, audio_assets: AudioTestAssets):
    """Exercises online serving with/without chunked prefill enabled.

    Sends one chat completion that contains every audio asset as an
    ``audio_url`` part, followed by a text question. It then checks that
    exactly one choice is returned and that the choice stopped with
    ``finish_reason == "length"`` (``max_tokens=10``).
    """
    messages = [{
        "role":
        "user",
        "content": [
            *[{
                "type": "audio_url",
                "audio_url": {
                    "url": audio.url
                }
            } for audio in audio_assets],
            {
                "type":
                "text",
                "text":
                f"What's happening in these {len(audio_assets)} audio clips?"
            },
        ],
    }]

    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10)

    assert len(chat_completion.choices) == 1
    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"