# test_llava.py 4.99 KB
# Newer Older
1
2
3
import gc
from dataclasses import fields
from enum import Enum
4
from typing import Any, Dict, List, Tuple
5
6
7
8
9
10
11

import pytest
import torch
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

def iter_llava_configs(model_name: str):
    """Yield ``(model_name, VisionLanguageConfig)`` pairs for a LLaVA model.

    For every supported image resolution, two configs are produced: one
    whose image input type is ``PIXEL_VALUES`` and one whose image input
    type is ``IMAGE_FEATURES``. Both use ``model_name`` as the image
    processor and ``32000`` as the image token id.
    """
    # Maps image (height, width) to the image feature size used for it.
    feature_size_by_resolution = {
        (336, 336): 576,
    }

    for (height, width), feature_size in feature_size_by_resolution.items():
        # Each input type is paired with the input shape it is tested with.
        # NOTE(review): 1024 is presumably the hidden size of the image
        # features -- confirm against the model definition.
        shape_by_input_type = (
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
             (1, 3, height, width)),
            (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES,
             (1, feature_size, 1024)),
        )
        for image_input_type, image_input_shape in shape_by_input_type:
            vl_config = VisionLanguageConfig(
                image_input_type=image_input_type,
                image_feature_size=feature_size,
                image_token_id=32000,
                image_input_shape=image_input_shape,
                image_processor=model_name,
                image_processor_revision=None,
            )
            yield (model_name, vl_config)


32
# Parametrization cases for `test_models`: every (model name, config) pair
# produced by `iter_llava_configs` for the models listed here.
model_and_vl_config = [
    *iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
    # Not enough memory
    # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
]


39
def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
    """Flatten vision language config to pure args.

    Compatible with what llm entrypoint expects.

    Every dataclass field of ``vlm_config`` becomes one entry keyed by the
    field name:

    * ``Enum`` values are replaced by their lower-cased member name.
    * ``tuple`` values are replaced by a comma-separated string of their
      items.
    * Any other value is passed through unchanged.

    An extra ``disable_image_processor`` entry is added; it is ``True``
    exactly when ``vlm_config.image_processor`` is ``None``.
    """
    result: Dict[str, Any] = {}

    for field in fields(vlm_config):
        value = getattr(vlm_config, field.name)
        if isinstance(value, Enum):
            result[field.name] = value.name.lower()
        elif isinstance(value, tuple):
            result[field.name] = ",".join(str(item) for item in value)
        else:
            result[field.name] = value

    # Not a dataclass field: derived from whether an image processor is set.
    result["disable_image_processor"] = vlm_config.image_processor is None

    return result


def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
                         vision_language_config: VisionLanguageConfig,
                         model_id: str):
    """Make a vllm ``(input_ids, output_str)`` pair comparable with hf output.

    The token ids are reduced from 1, 32000, 32000, ..., 32000, x1, x2, x3 ...
    to 1, 32000, x1, x2, x3 ... and the text is reduced from
    "<image><image>bla" to "bla".

    Returns the ``(sanitized_input_ids, sanitized_output_str)`` tuple.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Text form of a single image token, as decoded by the model's tokenizer.
    placeholder_text = hf_tokenizer.decode(
        vision_language_config.image_token_id)
    placeholder_len = len(placeholder_text)

    token_ids, text = vllm_output
    feature_size = vision_language_config.image_feature_size

    # Keep the first two ids, then skip the remaining `feature_size - 1`
    # repeated image tokens that follow them.
    tail_start = 2 + feature_size - 1
    trimmed_ids = token_ids[0:2] + token_ids[tail_start:]

    # Drop the leading run of `feature_size` image-token strings.
    trimmed_text = text[feature_size * placeholder_len:]

    return trimmed_ids, trimmed_text


@pytest.mark.parametrize("worker_use_ray", [False])
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
                vllm_image_prompts, vllm_images, model_and_config, dtype: str,
                max_tokens: int, worker_use_ray: bool) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalData objects and corresponding
    vision language config as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vision_language_config = model_and_config

    hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
                                          max_tokens,
                                          images=hf_images)
    # Drop our reference to the HF model before the vLLM model is created.
    del hf_model

    vllm_model = vllm_runner(model_id,
                             dtype=dtype,
                             worker_use_ray=worker_use_ray,
                             enforce_eager=True,
                             **as_dict(vision_language_config))
    vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                              max_tokens,
                                              images=vllm_images)
    del vllm_model

    # Both models are unreferenced now; collect garbage and release cached
    # CUDA memory before comparing the outputs.
    gc.collect()
    torch.cuda.empty_cache()

    for i in range(len(hf_image_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        # vLLM output carries the expanded image tokens; strip them so the
        # ids and text line up with what HF returns.
        vllm_output_ids, vllm_output_str = sanitize_vllm_output(
            vllm_outputs[i], vision_language_config, model_id)
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
124
125
126
127


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# (Requires multiple GPUs)