# test_ultravox.py
# SPDX-License-Identifier: Apache-2.0

3
import json
4
from typing import Any
5
6
7

import numpy as np
import pytest
8
import pytest_asyncio
9
from transformers import AutoTokenizer
10

11
from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
12
from ....utils import RemoteOpenAIServer
13
from ...registry import HF_EXAMPLE_MODELS

MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

# Per-asset text prompts built from the shared audio test assets
# (presumably keyed by audio asset name — confirm against conftest).
AUDIO_PROMPTS = AUDIO_ASSETS.prompts({
    "mary_had_lamb":
    "Transcribe this into English.",
    "winning_call":
    "What is happening in this audio clip?",
})

# Question asked after all audio placeholders in the multi-audio test.
MULTI_AUDIO_PROMPT = "Describe each of the audios above."

# An (audio samples, sample rate) pair, the shape of
# `audio.audio_and_sample_rate` on the test assets.
AudioTuple = tuple[np.ndarray, int]

# Placeholder token marking where each audio clip goes in a prompt.
VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>"

# Engine settings that force chunked prefill to be exercised.
CHUNKED_PREFILL_KWARGS = {
    "enable_chunked_prefill": True,
    "max_num_seqs": 2,
    # Use a very small limit to exercise chunked prefill.
    "max_num_batched_tokens": 16
}


def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
    """Translate engine keyword arguments into command-line flags.

    Underscores in each key become dashes. A ``True`` value emits a bare
    ``--flag``, a ``False`` value emits nothing, and any other value is
    rendered as ``--flag=value``.
    """
    cli_flags: list[str] = []
    for name, setting in params_kwargs.items():
        flag = "--" + name.replace("_", "-")
        if not isinstance(setting, bool):
            cli_flags.append(f"{flag}={setting}")
        elif setting:
            cli_flags.append(flag)
    return cli_flags


@pytest.fixture(params=[
    pytest.param({}, marks=pytest.mark.cpu_model),
    pytest.param(CHUNKED_PREFILL_KWARGS),
])
def server(request, audio_assets: AudioTestAssets):
    """Start a RemoteOpenAIServer for MODEL_NAME and yield it.

    Parametrized to run once with default engine settings and once with
    ``CHUNKED_PREFILL_KWARGS`` converted to CLI flags. The per-prompt audio
    limit equals the number of test audio assets, so a single request may
    carry all of them.
    """
    args = [
        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
        "--limit-mm-per-prompt",
        json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
    ] + params_kwargs_to_cli_args(request.param)

    # Audio is passed to the server by URL in the online test, so set an
    # explicit fetch timeout (value "30"; units presumably seconds — confirm).
    with RemoteOpenAIServer(MODEL_NAME,
                            args,
                            env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
                                      "30"}) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client for the ``server`` fixture.

    The client is opened with ``async with``, so it is closed once the
    consuming test finishes.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client


def _get_prompt(audio_count, question, placeholder):
    """Build a chat-formatted prompt containing ``audio_count`` audio slots.

    Each placeholder goes on its own line, followed by ``question``. The text
    is wrapped in a single user turn and rendered with the model's chat
    template, with the generation prompt appended.

    Args:
        audio_count: Number of audio placeholders to insert.
        question: Text appended after the placeholders.
        placeholder: Placeholder token for a single audio clip.

    Returns:
        The rendered, untokenized prompt string.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # One placeholder per line, e.g. "<|audio|>\n<|audio|>\n" for two clips.
    placeholders = f"{placeholder}\n" * audio_count
    messages = [{
        'role': 'user',
        'content': f"{placeholders}{question}",
    }]
    return tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True)


def run_multi_audio_test(
    vllm_runner: type[VllmRunner],
    prompts_and_audios: list[tuple[str, list[AudioTuple]]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    **kwargs,
):
    """Run greedy generation with several audios per prompt and sanity-check it.

    Skips (via ``on_fail="skip"``) when the model entry in
    ``HF_EXAMPLE_MODELS`` is not available online or the installed
    transformers version is unsupported.

    Args:
        vllm_runner: Runner class used to construct the vLLM model.
        prompts_and_audios: ``(prompt, audios)`` pairs; each audio is an
            ``(samples, sample_rate)`` tuple.
        model: Model name to load.
        dtype: Model dtype passed to the runner.
        max_tokens: Maximum number of tokens to generate per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        **kwargs: Extra engine arguments forwarded to ``vllm_runner``.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    # Allow as many audios per prompt as the largest request needs.
    max_audios_per_prompt = max(
        len(audios) for _, audios in prompts_and_audios)

    with vllm_runner(model,
                     dtype=dtype,
                     enforce_eager=True,
                     limit_mm_per_prompt={"audio": max_audios_per_prompt},
                     **kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            [prompt for prompt, _ in prompts_and_audios],
            max_tokens,
            num_logprobs=num_logprobs,
            audios=[audios for _, audios in prompts_and_audios])

    # The HuggingFace model doesn't support multiple audios yet, so
    # just assert that some tokens were generated.
    assert all(tokens for tokens, *_ in vllm_outputs)


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [
    pytest.param({}, marks=pytest.mark.cpu_model),
    pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models_with_multiple_audios(vllm_runner,
                                     audio_assets: AudioTestAssets, dtype: str,
                                     max_tokens: int, num_logprobs: int,
                                     vllm_kwargs: dict) -> None:
    """Generate from a single prompt that carries every test audio asset.

    Parametrized to run with default engine settings and with
    ``CHUNKED_PREFILL_KWARGS``.
    """
    vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT,
                              VLLM_PLACEHOLDER)
    run_multi_audio_test(
        vllm_runner,
        [(vllm_prompt, [audio.audio_and_sample_rate
                        for audio in audio_assets])],
        MODEL_NAME,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        **vllm_kwargs,
    )


@pytest.mark.asyncio
async def test_online_serving(client, audio_assets: AudioTestAssets):
    """Exercises online serving with/without chunked prefill enabled.

    Sends every test audio asset by URL in one chat request and checks that
    exactly one choice is returned with finish reason ``"length"`` (i.e.
    generation ran into the ``max_tokens=10`` limit).
    """
    messages = [{
        "role":
        "user",
        "content": [
            *[{
                "type": "audio_url",
                "audio_url": {
                    "url": audio.url
                }
            } for audio in audio_assets],
            {
                "type":
                "text",
                "text":
                f"What's happening in these {len(audio_assets)} audio clips?"
            },
        ],
    }]

    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10)

    assert len(chat_completion.choices) == 1
    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"