# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from typing import Any

import librosa
8
import pytest
9
from transformers import AutoModelForSpeechSeq2Seq
10
11

from vllm.assets.audio import AudioAsset
12
from vllm.platforms import current_platform
13

14
from ....conftest import HfRunner, PromptAudioInput, VllmRunner
15
from ....utils import create_new_process_for_each_test, multi_gpu_test
16
17
18
19
20
21
22
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close

# Decoder prompt for vLLM, spelled out with Whisper's special tokens:
# start of transcript, English, transcribe task, no timestamps.
VLLM_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# The HF runner is given an empty text prompt.
# NOTE(review): presumably HF generation inserts the decoder start tokens
# itself — confirm against the HfRunner implementation.
HF_PROMPT = ""
# Whisper expects 16kHz audio; the audio fixture resamples to this rate.
WHISPER_SAMPLE_RATE = 16000


@pytest.fixture(autouse=True)
def use_spawn_for_whisper(monkeypatch):
    """Force vLLM worker processes in this module to start with ``spawn``.

    Whisper has issues with forked workers, so this autouse fixture sets
    ``VLLM_WORKER_MULTIPROC_METHOD=spawn`` for every test in the module.
    ``monkeypatch`` restores the previous environment after each test.
    """
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: Sequence[tuple[list[str], list[str], PromptAudioInput]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: str | None = None,
    enforce_eager: bool = True,
) -> None:
    """Inference result should be the same between hf and vllm.

    Each element of ``inputs`` is one test case of the form
    ``(vllm_prompts, hf_prompts, audios)``. Both runners receive the same
    ``audios`` (built from ``AudioAsset`` by the ``input_audios`` fixture);
    only the text prompt differs per backend.

    The vLLM runner's context is exited before the HF runner is created
    (presumably so both models are not resident at once). For each case,
    the greedy top-``num_logprobs`` outputs of the two backends are then
    compared with ``check_logprobs_close``.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        limit_mm_per_prompt={"audio": 2},
        enforce_eager=enforce_eager,
        disable_custom_all_reduce=True,
    ) as vllm_model:
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(
                vllm_prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                audios=audios,
            )
            for vllm_prompts, _, audios in inputs
        ]

    with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(
                hf_prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                audios=audios,
            )
            for _, hf_prompts, audios in inputs
        ]

    # Both lists are built from ``inputs``; strict=True turns a silently
    # dropped case (e.g. a one-shot iterable passed as ``inputs``) into an
    # error instead of a comparison over fewer cases.
    for hf_outputs, vllm_outputs in zip(
        hf_outputs_per_case, vllm_outputs_per_case, strict=True
    ):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


@pytest.fixture
def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
    """Build one test case per bundled audio asset.

    Each case is ``([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)])``:
    the vLLM prompts, the HF prompts, and the audio inputs shared by both
    runners. Audio whose native sample rate differs from
    ``WHISPER_SAMPLE_RATE`` is resampled with ``librosa`` first.
    """
    audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
    inputs = []
    for asset in audio_assets:
        audio, orig_sr = asset.audio_and_sample_rate
        # Resample to Whisper's expected sample rate (16kHz)
        if orig_sr != WHISPER_SAMPLE_RATE:
            audio = librosa.resample(
                audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
            )
        # vLLM prompts, HF prompts, audio inputs
        inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)]))
    return inputs


def check_model_available(model: str) -> None:
    """Skip the calling test unless ``model`` can be used in this environment.

    Looks ``model`` up in ``HF_EXAMPLE_MODELS``. The test is skipped (not
    failed) if the online-availability check or the ``transformers`` version
    check does not pass.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")


@pytest.mark.core_model
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("enforce_eager", [True, False])
@create_new_process_for_each_test("spawn")
def test_models(
    hf_runner,
    vllm_runner,
    model: str,
    dtype: str,
    num_logprobs: int,
    input_audios,
    enforce_eager: bool,
) -> None:
    """Compare HF and vLLM greedy logprobs for Whisper with tensor parallel 1.

    Runs with and without ``enforce_eager``; the non-eager variant is skipped
    on CPU.
    """
    # Take the cheap, platform-only skip first, before the availability check
    # (which, via ``check_available_online``, may need to reach the network).
    if current_platform.is_cpu() and not enforce_eager:
        pytest.skip("Skipping test for CPU with non-eager mode")
    check_model_available(model)
    run_test(
        hf_runner,
        vllm_runner,
        input_audios,
        model,
        dtype=dtype,
        # NOTE(review): 448 presumably matches Whisper's decoder context
        # length (max_target_positions) — confirm against the model config.
        max_model_len=448,
        max_tokens=200,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
        enforce_eager=enforce_eager,
    )


@multi_gpu_test(num_gpus=2)
@pytest.mark.core_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [200])
@pytest.mark.parametrize("num_logprobs", [5])
@create_new_process_for_each_test("spawn")
def test_models_distributed(
    hf_runner,
    vllm_runner,
    model: str,
    distributed_executor_backend: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    input_audios,
) -> None:
    """Compare HF and vLLM greedy logprobs for Whisper with tensor parallel 2.

    Exercises both the ``ray`` and ``mp`` distributed executor backends,
    with ``enforce_eager=False``.
    """
    check_model_available(model)
    run_test(
        hf_runner,
        vllm_runner,
        input_audios,
        model,
        dtype=dtype,
        max_model_len=448,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=2,
        distributed_executor_backend=distributed_executor_backend,
        enforce_eager=False,
    )