"""Compare the outputs of HF and distributed vLLM when using greedy sampling.

The multimodal models below are run with tensor parallelism across two GPUs;
the test is skipped when fewer than two GPUs are available.

Run:
```sh
pytest -s -v test_multimodal_broadcast.py
```
"""

import pytest

from vllm.utils import cuda_device_count_stateless

13
14
15
16
17
18
19
20
from ..utils import fork_new_process_for_each_test


@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
    ("llava-hf/llava-1.5-7b-hf", "ray"),
    ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
    ("facebook/chameleon-7b", "ray"),
    ("llava-hf/llava-1.5-7b-hf", "mp"),
    ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
    ("facebook/chameleon-7b", "mp"),
])
@fork_new_process_for_each_test
def test_models(hf_runner, vllm_runner, image_assets, model: str,
                distributed_executor_backend: str) -> None:
    """Run a multimodal model with tensor parallelism and compare against HF.

    Each ``(model, distributed_executor_backend)`` pair from the parametrize
    table is dispatched, by model-name prefix, to the ``run_test`` helper of
    the matching single-model test module. That helper is invoked with
    ``tensor_parallel_size=2`` and the selected executor backend.

    Args:
        hf_runner: Forwarded to ``run_test``; presumably the pytest fixture
            that builds the HF reference runner.
        vllm_runner: Forwarded to ``run_test``; presumably the pytest fixture
            that builds the vLLM runner.
        image_assets: Forwarded to ``run_test``; presumably the shared image
            inputs fixture.
        model: HF model id from the parametrize table. Only its prefix is
            used here, to choose which per-model test module to delegate to.
        distributed_executor_backend: Executor backend forwarded to vLLM
            (``"ray"`` or ``"mp"`` in the parametrize table).

    Raises:
        NotImplementedError: If ``model`` matches none of the known prefixes.
    """
    dtype = "half"
    max_tokens = 5
    num_logprobs = 5
    # The distributed part of the test: shard the model across two GPUs,
    # matching the skipif guard above.
    tensor_parallel_size = 2

    # Import only the helper module for the selected model family. The
    # `type: ignore[no-redef]` comments silence mypy, which sees `run_test`
    # and `models` bound in several branches.
    if model.startswith("llava-hf/llava-1.5"):
        from ..models.test_llava import models, run_test
    elif model.startswith("llava-hf/llava-v1.6"):
        from ..models.test_llava_next import run_test  # type: ignore[no-redef]
        from ..models.test_llava_next import models
    elif model.startswith("facebook/chameleon"):
        from ..models.test_chameleon import run_test  # type: ignore[no-redef]
        from ..models.test_chameleon import models
    else:
        raise NotImplementedError(f"Unsupported model: {model}")

    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        # NOTE(review): this runs the first entry of the selected module's
        # `models` list, not `model` itself; assumes the two coincide —
        # confirm against the per-model test modules.
        model=models[0],
        # So that LLaVA-NeXT processor may return nested list
        size_factors=[0.25, 0.5, 1.0],
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )