test_pixtral.py 6.32 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import json
from dataclasses import asdict
5
from typing import TYPE_CHECKING, Any
6

7
import os
Patrick von Platen's avatar
Patrick von Platen committed
8
import pytest
9
from mistral_common.multimodal import download_image
10
from mistral_common.protocol.instruct.chunk import ImageURLChunk
11
12
13
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
14
from transformers import AutoProcessor
15

16
from vllm import SamplingParams, TextPrompt, TokensPrompt
17
from vllm.logprobs import Logprob, SampleLogprobs
18
from vllm.multimodal import MultiModalDataBuiltins
Patrick von Platen's avatar
Patrick von Platen committed
19

20
from ....utils import VLLM_PATH, large_gpu_test
21
from ...utils import check_logprobs_close, models_path_prefix
Patrick von Platen's avatar
Patrick von Platen committed
22

23
24
if TYPE_CHECKING:
    from _typeshed import StrPath
Patrick von Platen's avatar
Patrick von Platen committed
25

26
27
# Model checkpoints are resolved inside the local model store
# (``models_path_prefix``). These exact strings are also the keys of
# ``FIXTURE_LOGPROBS_CHAT``, so ``MODELS`` must reuse them unchanged.
PIXTRAL_ID = os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409")
MISTRAL_SMALL_3_1_ID = os.path.join(
    models_path_prefix, "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)

# The IDs above already include ``models_path_prefix``. Joining it a second
# time would double the prefix whenever it is a relative path, and the
# ``FIXTURE_LOGPROBS_CHAT[model]`` lookup in the test would then fail with
# ``KeyError``.
MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]

# Test images, resolved at test time through the ``local_asset_server``
# fixture. They were originally fetched from the files of the same name under
# https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/
IMG_URLS = [
    "237-400x300.jpg",
    "231-200x300.jpg",
    "27-500x500.jpg",
    "17-150x600.jpg",
]
PROMPT = "Describe each image in one short sentence."


40
def _create_msg_format(urls: list[str]) -> list[dict[str, Any]]:
    """Build a single-turn chat in the format used by vLLM / mistral_common.

    The user turn contains the text part ``PROMPT`` first. It is followed by
    one ``image_url`` part per entry of *urls*, in the same order.
    """
    parts: list[dict[str, Any]] = [{"type": "text", "text": PROMPT}]
    for url in urls:
        parts.append({"type": "image_url", "image_url": {"url": url}})

    user_turn = {"role": "user", "content": parts}
    return [user_turn]
53
54


55
def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]:
    """Build the same single-turn chat for the Hugging Face processor path.

    Each URL is downloaded eagerly with ``download_image``. The image object
    is embedded under the ``"image"`` key, in URL order, after the text part.
    """
    # NOTE(review): the text part uses the key "content" (not "text").
    # Presumably this matches the HF Pixtral chat template; confirm.
    content: list[dict[str, Any]] = [{"type": "text", "content": PROMPT}]
    content.extend({"type": "image", "image": download_image(url)} for url in urls)
    return [{"role": "user", "content": content}]
68
69


70
def _create_engine_inputs(urls: list[str]) -> TokensPrompt:
    """Tokenize the chat with mistral_common and package it for vLLM.

    Returns a ``TokensPrompt``. It holds the token ids produced by the
    ``"pixtral"`` ``MistralTokenizer``, plus a ``multi_modal_data`` entry with
    one image per ``ImageURLChunk`` in the request.
    """
    messages = _create_msg_format(urls)
    mistral_tokenizer = MistralTokenizer.from_model("pixtral")

    request = ChatCompletionRequest(messages=messages)  # type: ignore[type-var]
    token_ids = mistral_tokenizer.encode_chat_completion(request).tokens

    # The validated request holds typed chunk objects. Convert only the image
    # chunks with ``image_from_chunk``, so vLLM gets the same images the
    # tokenizer saw.
    images = [
        image_from_chunk(part)
        for part in request.messages[0].content
        if isinstance(part, ImageURLChunk)
    ]

    prompt = TokensPrompt(prompt_token_ids=token_ids)
    prompt["multi_modal_data"] = MultiModalDataBuiltins(image=images)
    return prompt


91
def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt:
    """Render the chat through the HF processor's chat template for vLLM.

    Returns a ``TextPrompt``. Its prompt text comes from
    ``apply_chat_template``. Its ``multi_modal_data`` reuses the images that
    ``_create_msg_format_hf`` already downloaded.
    """
    messages = _create_msg_format_hf(urls)

    processor = AutoProcessor.from_pretrained(
        os.path.join(models_path_prefix, "mistral-community/pixtral-12b")
    )
    rendered = processor.apply_chat_template(messages)

    # Pull the embedded images back out instead of downloading them again.
    images = [
        part["image"] for part in messages[0]["content"] if part["type"] == "image"
    ]

    return TextPrompt(
        prompt=rendered,
        multi_modal_data=MultiModalDataBuiltins(image=images),
    )


108
109
110
111
# Deterministic decoding (temperature 0) requesting 5 logprobs per generated
# token, so generations can be compared against the stored fixtures.
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = {"image": 4}

# Context lengths that ``test_chat`` is parametrized over.
MAX_MODEL_LEN = [8192, 65536]

FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists()

# Golden outputs per model, keyed by ``PIXTRAL_ID`` / ``MISTRAL_SMALL_3_1_ID``.
# ``test_chat`` looks them up using its ``model`` parameter.
FIXTURE_LOGPROBS_CHAT = {
    PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
    MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
}

# One tuple per generated sequence: (token ids, decoded text, per-step logprobs).
OutputsLogprobs = list[tuple[list[int], str, SampleLogprobs | None]]
122

123
124

# Helper for test authors: writes golden outputs as a JSON fixture.
def _dump_outputs_w_logprobs(
    outputs: OutputsLogprobs,
    filename: "StrPath",
) -> None:
    """Serialize ``(token ids, text, logprobs)`` tuples to a JSON file.

    Each step's ``{token_id: Logprob}`` mapping is converted with
    ``dataclasses.asdict``. A ``None`` logprobs entry is written as an empty
    list. JSON turns the integer token-id keys into strings;
    ``load_outputs_w_logprobs`` converts them back.
    """
    serializable = []
    for token_ids, text, step_logprobs in outputs:
        steps = [
            {token_id: asdict(entry) for token_id, entry in step.items()}
            for step in (step_logprobs or [])
        ]
        serializable.append((token_ids, text, steps))

    with open(filename, "w") as fh:
        json.dump(serializable, fh)


145
def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
    """Read golden outputs in the layout written by ``_dump_outputs_w_logprobs``.

    JSON object keys are always strings, so each token id is converted back
    to ``int``. Each logprob record is rebuilt as a ``Logprob``.
    """
    with open(filename, "rb") as fh:
        raw = json.load(fh)

    loaded = []
    for token_ids, text, steps in raw:
        step_logprobs = [
            {int(token_id): Logprob(**fields) for token_id, fields in step.items()}
            for step in steps
        ]
        loaded.append((token_ids, text, step_logprobs))
    return loaded
Patrick von Platen's avatar
Patrick von Platen committed
160
161


162
@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat(
    vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server
) -> None:
    """Chat generations must stay close to the model's stored reference logprobs.

    Runs three chats: the first image, the first two images, and all of
    ``IMG_URLS``. The images are served by ``local_asset_server``. The results
    are compared against the model's fixture with ``check_logprobs_close``.
    """
    expected = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model])

    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
        max_model_len=max_model_len,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = []
        image_urls = [local_asset_server.url_for(name) for name in IMG_URLS]
        # Growing prefixes of the image list: one image, two images, then all.
        conversations = [
            _create_msg_format(image_urls[:count])
            for count in (1, 2, len(image_urls))
        ]
        for conversation in conversations:
            outputs.extend(
                vllm_model.llm.chat(conversation, sampling_params=SAMPLING_PARAMS)
            )

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)

    # Each entry ends with a `None` prompt_logprobs slot that the fixtures do
    # not store. Check that it is None, then drop it before comparing.
    for idx, entry in enumerate(logprobs):
        assert entry[-1] is None
        logprobs[idx] = entry[:-1]

    check_logprobs_close(
        outputs_0_lst=expected,
        outputs_1_lst=logprobs,
        name_0="h100_ref",
        name_1="output",
    )