test_pixtral.py 7.75 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import json
from dataclasses import asdict
5
from typing import TYPE_CHECKING, Any
6

Patrick von Platen's avatar
Patrick von Platen committed
7
import pytest
8
from mistral_common.multimodal import download_image
9
from mistral_common.protocol.instruct.chunk import ImageURLChunk
10
11
12
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
13
from transformers import AutoProcessor
14

15
from vllm import SamplingParams, TextPrompt, TokensPrompt
16
from vllm.inputs import MultiModalDataBuiltins
17
from vllm.logprobs import Logprob, SampleLogprobs
18
from vllm.platforms import current_platform
Patrick von Platen's avatar
Patrick von Platen committed
19

20
from ....utils import VLLM_PATH, large_gpu_test
21
from ...utils import check_logprobs_close
Patrick von Platen's avatar
Patrick von Platen committed
22

23
24
if TYPE_CHECKING:
    from _typeshed import StrPath
Patrick von Platen's avatar
Patrick von Platen committed
25

26
27
PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
# Not part of MODELS: this checkpoint is exercised on its own by
# test_chat_consolidated.
MINISTRAL_3B_ID = "mistralai/Ministral-3-3B-Instruct-2512"

MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]

# Image file names, resolved to URLs through the `local_asset_server` fixture.
# Each file was originally downloaded from
# https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/<name>
IMG_URLS = [
    "237-400x300.jpg",
    "231-200x300.jpg",
    "27-500x500.jpg",
    "17-150x600.jpg",
]
PROMPT = "Describe each image in one short sentence."


def _create_msg_format(urls: list[str]) -> list[dict[str, Any]]:
    """Build a single-turn OpenAI-style chat message for ``urls``.

    The user turn holds one text chunk containing ``PROMPT``, followed by one
    ``image_url`` chunk per URL, in the order given.
    """
    text_chunk = {"type": "text", "text": PROMPT}
    image_chunks = [{"type": "image_url", "image_url": {"url": url}} for url in urls]
    return [{"role": "user", "content": [text_chunk, *image_chunks]}]


def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]:
    """Build a single-turn chat message in the Hugging Face processor format.

    The user turn holds a text chunk with ``PROMPT``, followed by one
    ``image`` chunk per URL. Each image chunk carries the object returned by
    ``download_image(url)``, and the images are fetched eagerly, in order.
    """
    # NOTE(review): the text chunk stores the prompt under "content" rather
    # than "text"; presumably this is what the HF Pixtral chat template
    # reads. Confirm against the template before changing it.
    content: list[dict[str, Any]] = [{"type": "text", "content": PROMPT}]
    content.extend({"type": "image", "image": download_image(url)} for url in urls)
    return [{"role": "user", "content": content}]


def _create_engine_inputs(urls: list[str]) -> TokensPrompt:
    """Tokenize a chat request for ``urls`` with mistral_common's Pixtral tokenizer.

    Returns a ``TokensPrompt`` with the encoded prompt token ids. The images
    are recovered from the request's ``ImageURLChunk`` entries via
    ``image_from_chunk`` and attached as ``multi_modal_data``.
    """
    chat_request = ChatCompletionRequest(
        messages=_create_msg_format(urls)  # type: ignore[type-var]
    )
    mistral_tokenizer = MistralTokenizer.from_model("pixtral")
    encoded = mistral_tokenizer.encode_chat_completion(chat_request)

    # The request object has parsed the raw dicts into typed chunks; only
    # the image chunks are turned into image data.
    images = [
        image_from_chunk(chunk)
        for chunk in chat_request.messages[0].content
        if isinstance(chunk, ImageURLChunk)
    ]

    return TokensPrompt(
        prompt_token_ids=encoded.tokens,
        multi_modal_data=MultiModalDataBuiltins(image=images),
    )


def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt:
    """Render a chat prompt for ``urls`` with the HF Pixtral processor.

    The prompt text comes from ``apply_chat_template`` on the
    ``mistral-community/pixtral-12b`` processor. The downloaded images from
    the message are attached as ``multi_modal_data``.
    """
    message = _create_msg_format_hf(urls)

    # AutoProcessor returns a processor (not a bare tokenizer); it owns the
    # chat template used to render the prompt string.
    processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
    prompt = processor.apply_chat_template(message)

    images = [
        chunk["image"] for chunk in message[0]["content"] if chunk["type"] == "image"
    ]
    return TextPrompt(
        prompt=prompt, multi_modal_data=MultiModalDataBuiltins(image=images)
    )


# Greedy decoding (temperature 0) that also returns 5 logprobs per generated
# token, so outputs can be compared against stored reference logprobs.
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
# Up to four images per prompt: the largest request sends every IMG_URLS entry.
LIMIT_MM_PER_PROMPT = {"image": 4}

MAX_MODEL_LEN = [8192, 65536]

FIXTURES_PATH = VLLM_PATH / "tests" / "models" / "fixtures"
assert FIXTURES_PATH.exists()

# Golden chat outputs (token ids, text, per-token logprobs) for each model.
FIXTURE_LOGPROBS_CHAT = {
    PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
    MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
    MINISTRAL_3B_ID: FIXTURES_PATH / "ministral_3b_chat.json",
}

# One entry per request: (generated token ids, generated text,
# per-token {token_id: Logprob} mappings or None).
OutputsLogprobs = list[tuple[list[int], str, SampleLogprobs | None]]

# Helper for the test author: write golden outputs to a JSON fixture file.
def _dump_outputs_w_logprobs(
    outputs: OutputsLogprobs,
    filename: "StrPath",
) -> None:
    """Serialize ``outputs`` to ``filename`` as JSON.

    Each entry becomes ``[tokens, text, per_token]``. ``per_token`` is a list
    of ``{token_id: asdict(logprob)}`` mappings; a ``None`` logprobs entry is
    written as an empty list. JSON turns the integer token-id keys into
    strings, which ``load_outputs_w_logprobs`` converts back with ``int``.
    """
    serializable = []
    for token_ids, output_text, sample_logprobs in outputs:
        per_token = [
            {token_id: asdict(entry) for token_id, entry in step.items()}
            for step in (sample_logprobs or [])
        ]
        serializable.append((token_ids, output_text, per_token))

    with open(filename, "w") as fp:
        json.dump(serializable, fp)


def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
    """Load golden outputs written by ``_dump_outputs_w_logprobs``.

    Returns one ``(tokens, text, per_token)`` tuple per stored entry. Each
    ``per_token`` item maps the ``int`` token id (converted back from the
    JSON string key) to a ``Logprob`` rebuilt from its stored fields.
    """
    with open(filename, "rb") as fp:
        raw_entries = json.load(fp)

    loaded: OutputsLogprobs = []
    for token_ids, output_text, per_token in raw_entries:
        steps = [
            {int(token_id): Logprob(**fields) for token_id, fields in step.items()}
            for step in per_token
        ]
        loaded.append((token_ids, output_text, steps))
    return loaded


@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat(
    vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server
) -> None:
    """Compare chat outputs and logprobs for ``model`` against golden fixtures.

    Sends three chat requests with the first 1, the first 2 and all of
    ``IMG_URLS`` (served by ``local_asset_server``). The resulting logprobs are
    checked against ``FIXTURE_LOGPROBS_CHAT[model]`` with
    ``check_logprobs_close``.
    """
    if (
        model == MISTRAL_SMALL_3_1_ID
        and max_model_len == 65536
        and current_platform.is_rocm()
    ):
        pytest.skip(
            "OOM on ROCm: 24B model with 65536 context length exceeds GPU memory"
        )

    expected_chat_logprobs = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model])
    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
        max_model_len=max_model_len,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        urls_all = [local_asset_server.url_for(u) for u in IMG_URLS]
        msgs = [
            _create_msg_format(urls_all[:1]),
            _create_msg_format(urls_all[:2]),
            _create_msg_format(urls_all),
        ]
        outputs = []
        for msg in msgs:
            outputs.extend(vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS))

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)

    # Each entry ends with a `None` prompt_logprobs slot that the fixture does
    # not store: check that it is None, then drop it before comparing.
    for entry in logprobs:
        assert entry[-1] is None
    logprobs = [entry[:-1] for entry in logprobs]

    check_logprobs_close(
        outputs_0_lst=expected_chat_logprobs,
        outputs_1_lst=logprobs,
        name_0="h100_ref",
        name_1="output",
    )


@large_gpu_test(min_gb=16)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat_consolidated(vllm_runner, dtype: str, local_asset_server) -> None:
    """Compare Ministral-3B chat outputs and logprobs against golden fixtures.

    Same request pattern as ``test_chat`` (first 1, first 2 and all images of
    ``IMG_URLS``), run on ``MINISTRAL_3B_ID`` with ``max_model_len=8192``.
    """
    expected_chat_logprobs = load_outputs_w_logprobs(
        FIXTURE_LOGPROBS_CHAT[MINISTRAL_3B_ID]
    )
    with vllm_runner(
        MINISTRAL_3B_ID,
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
        max_model_len=8192,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        urls_all = [local_asset_server.url_for(u) for u in IMG_URLS]
        msgs = [
            _create_msg_format(urls_all[:1]),
            _create_msg_format(urls_all[:2]),
            _create_msg_format(urls_all),
        ]
        outputs = []
        for msg in msgs:
            outputs.extend(vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS))

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)

    # Each entry ends with a `None` prompt_logprobs slot that the fixture does
    # not store: check that it is None, then drop it before comparing.
    for entry in logprobs:
        assert entry[-1] is None
    logprobs = [entry[:-1] for entry in logprobs]

    check_logprobs_close(
        outputs_0_lst=expected_chat_logprobs,
        outputs_1_lst=logprobs,
        name_0="h100_ref",
        name_1="output",
    )