# test_mapper.py: tests for vLLM's multimodal input mapping (MultiModalRegistry.map_input).
1
2
from contextlib import nullcontext

3
4
import numpy as np
import pytest
5
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
6

7
from vllm.config import ModelConfig
8
from vllm.multimodal import MultiModalRegistry
9
from vllm.multimodal.utils import rescale_image_size
10

11

12
13
14
15
16
@pytest.fixture
def mm_registry():
    """Build a new ``MultiModalRegistry`` for the requesting test.

    The fixture uses pytest's default (function) scope, so per-prompt limits
    configured by one test do not leak into another.
    """
    registry = MultiModalRegistry()
    return registry


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
    """Check vLLM's image mapping against HF's ``CLIPImageProcessor``.

    Each image asset is rescaled by ``size_factor``. It is then preprocessed
    both by the HF processor and by ``mm_registry.map_input``. The two results
    must expose the same keys, and each tensor must match in shape and be
    ``np.allclose`` in value.
    """
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, CLIPImageProcessor)

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
        limit_mm_per_prompt={"image": 1},
    )

    # The registry must know the per-prompt limits before map_input is used.
    mm_registry.init_mm_limits_per_prompt(model_config)

    for asset in image_assets:
        image = rescale_image_size(asset.pil_image, size_factor)

        hf_result = hf_processor.preprocess(
            image,
            return_tensors="pt",
        )
        vllm_result = mm_registry.map_input(
            model_config,
            {"image": image},
        )

        assert hf_result.keys() == vllm_result.keys()
        for key, hf_tensor in hf_result.items():
            hf_arr: np.ndarray = hf_tensor.numpy()
            vllm_arr: np.ndarray = vllm_result[key].numpy()

            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
                                    size_factor):
    """Check vLLM's image mapping against HF's ``LlavaNextImageProcessor``.

    Each image asset is rescaled by ``size_factor``. It is then preprocessed
    both by the HF processor and by ``mm_registry.map_input``. The two results
    must expose the same keys, and each tensor must match in shape and be
    ``np.allclose`` in value.
    """
    MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"

    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, LlavaNextImageProcessor)

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
        limit_mm_per_prompt={"image": 1},
    )

    # The registry must know the per-prompt limits before map_input is used.
    mm_registry.init_mm_limits_per_prompt(model_config)

    for asset in image_assets:
        image = rescale_image_size(asset.pil_image, size_factor)

        hf_result = hf_processor.preprocess(
            image,
            return_tensors="pt",
        )
        vllm_result = mm_registry.map_input(
            model_config,
            {"image": image},
        )

        assert hf_result.keys() == vllm_result.keys()
        for key, hf_tensor in hf_result.items():
            hf_arr: np.ndarray = hf_tensor.numpy()
            vllm_arr: np.ndarray = vllm_result[key].numpy()

            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
    """Check that ``map_input`` enforces the per-prompt image limit.

    In the parametrization, a case is marked invalid exactly when
    ``num_images > limit``. Invalid cases must raise ``ValueError``; valid
    cases must map without error.
    """
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    mm_registry.init_mm_limits_per_prompt(model_config)

    # Build inputs in each of the three forms exercised here: no "image" key,
    # a single bare image, and a list of images.
    image = image_assets[0].pil_image
    if num_images == 0:
        mm_inputs = {}
    elif num_images == 1:
        mm_inputs = {"image": image}
    else:
        mm_inputs = {"image": [image] * num_images}

    with nullcontext() if is_valid else pytest.raises(ValueError):
        mm_registry.map_input(model_config, mm_inputs)


# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
    """Check that a list of images maps to one ``pixel_values`` entry each.

    The per-prompt image limit is set to exactly ``num_images``, so the
    request sits at the limit and must be accepted.
    """
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": num_images},
    )

    mm_registry.init_mm_limits_per_prompt(model_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": [image] * num_images}

    mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
    assert len(mapped_inputs["pixel_values"]) == num_images