# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, TypedDict
5
6
7
8
9
10

import numpy.typing as npt
import pytest
import torch
from PIL import Image

11
12
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
13

14
15
16
17
18
19
20
from ....conftest import (
    IMAGE_ASSETS,
    VIDEO_ASSETS,
    PromptImageInput,
    PromptVideoInput,
    VllmRunner,
)
21
22
from ...utils import check_logprobs_close

23
24

@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Allow insecure serialization for every test in this module.

    `LLM.apply_model` requires pickling a function, so the environment
    variable `VLLM_ALLOW_INSECURE_SERIALIZATION` is set to "1" for the
    duration of each test (autouse, function scope).
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
28
29


30
31
32
33
34
# Model checkpoints that the tests in this module are parametrized over.
models = ["Qwen/Qwen2-VL-2B-Instruct"]
# dtype string the tests are parametrized with and forward to the runner.
target_dtype = "half"

# Placeholder token sequences marking where an image / a video goes in a prompt.
IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
# NOTE(review): presumably the hidden size of the checkpoint above; it is not
# referenced by the code visible in this file - confirm before relying on it.
MODEL_HIDDEN_SIZE = 1536
36
37
38
39
40
41


def qwen2_vl_chat_template(*query):
    """Wrap the concatenated `query` pieces in the chat layout used by these tests.

    The pieces are joined without a separator and placed in the user turn,
    after a fixed system turn and before the opening of the assistant turn.
    """
    user_content = "".join(query)
    system_turn = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    user_turn = f"<|im_start|>user\n{user_content}<|im_end|>"
    assistant_header = "<|im_start|>assistant\n"
    return system_turn + user_turn + assistant_header


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Prompts keyed by image asset name: each is the chat template wrapped around
# one image placeholder followed by the question text.
IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    {
        "stop_sign": qwen2_vl_chat_template(
            IMAGE_PLACEHOLDER,
            "What is the biggest text's content in this image?",
        ),
        "cherry_blossom": qwen2_vl_chat_template(
            IMAGE_PLACEHOLDER,
            "What is the season shown in this image? ",
            "Reply with a short sentence (no more than 20 words)",
        ),
    }
)

# Prompts keyed by video asset name, built the same way around the video
# placeholder.
VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
    {
        "baby_reading": qwen2_vl_chat_template(
            VIDEO_PLACEHOLDER,
            "Describe this video with a short sentence ",
            "(no more than 20 words)",
        ),
    }
)
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

# Single prompt carrying two image placeholders; used by the multi-image test,
# which supplies two images per prompt.
MULTIIMAGE_PROMPT = qwen2_vl_chat_template(
    IMAGE_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
    "Describe these two images separately. ",
    "For each image, reply with a short sentence ",
    "(no more than 10 words).",
)


class Qwen2VLPromptImageEmbeddingInput(TypedDict):
    """Image input for one prompt, given as precomputed embeddings."""

    # Embedding rows for this prompt's image(s), sliced along dim 0 from the
    # vision tower output.
    image_embeds: torch.Tensor
    # One row per image of this prompt, taken from the image processor's
    # "image_grid_thw" output.
    image_grid_thw: torch.Tensor


class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
    """Video input for one prompt, given as precomputed embeddings."""

    # Embedding rows for this prompt's video(s), sliced along dim 0 from the
    # vision tower output.
    video_embeds: torch.Tensor
    # One row per video of this prompt, taken from the video processor's
    # "video_grid_thw" output.
    video_grid_thw: torch.Tensor


def batch_make_image_embeddings(
    image_batches: list[Image.Image | list[Image.Image]],
    processor,
    llm: VllmRunner,
) -> list[Qwen2VLPromptImageEmbeddingInput]:
    """Compute batched image embeddings for Qwen2-VL.

    All images are embedded in a single pass through the model's vision
    tower, and the result is split back according to the input batches.

    Args:
        image_batches: one entry per prompt. Each entry is either
          - a single image (`Image.Image`), or
          - a list of images (`list[Image.Image]`) for multi-image prompts.
        processor: object exposing `image_processor`, whose `preprocess`
          result provides "pixel_values" and "image_grid_thw" and which has a
          `merge_size` attribute.
        llm: runner whose `apply_model` executes a function on the model.

    Returns:
        One `Qwen2VLPromptImageEmbeddingInput` per entry of `image_batches`,
        in the same order.

    Raises:
        AssertionError: if the embeddings or grids consumed while splitting
          do not add up to what the vision tower / processor produced.
    """
    # Normalise every entry to a list of images (single image -> [image]).
    normalized_batches: list[Any] = [
        batch if isinstance(batch, list) else [batch] for batch in image_batches
    ]

    # Flatten all images into one list so they are processed as one batch.
    images: list[Image.Image] = [
        image for batch in normalized_batches for image in batch
    ]

    # image to pixel values
    image_processor = processor.image_processor
    preprocess_result = image_processor.preprocess(
        images=images, return_tensors="pt"
    ).data
    pixel_values = preprocess_result["pixel_values"]
    image_grid_thw = preprocess_result["image_grid_thw"]

    # pixel values to embeddings & grid_thws
    def get_image_embeds(model):
        with torch.no_grad():
            visual = model.visual

            pixel_values_on_device = pixel_values.to(
                visual.device, dtype=visual.dtype
            )
            return visual(pixel_values_on_device, grid_thw=image_grid_thw).cpu()

    # `apply_model` returns a sequence of tensors (presumably one per worker);
    # they are concatenated along dim 0.
    image_embeds = torch.concat(llm.apply_model(get_image_embeds))

    # split into original batches
    merge_size = image_processor.merge_size
    result: list[Qwen2VLPromptImageEmbeddingInput] = []
    image_counter = 0
    embed_counter = 0
    for batch in normalized_batches:
        image_end = image_counter + len(batch)
        batch_grid_thw = image_grid_thw[image_counter:image_end]

        # Each image contributes prod(grid_thw) // merge_size // merge_size
        # embedding rows; coerce to a plain int so the counters and slice
        # bounds are ordinary integers rather than 0-dim tensors.
        batch_embed_len = int(
            sum(
                grid_thw.prod(-1) // merge_size // merge_size
                for grid_thw in batch_grid_thw
            )
        )
        embed_end = embed_counter + batch_embed_len

        result.append(
            {
                "image_embeds": image_embeds[embed_counter:embed_end],
                "image_grid_thw": batch_grid_thw,
            }
        )

        embed_counter = embed_end
        image_counter = image_end

    # ensure we don't lose any images or embeddings
    assert embed_counter == image_embeds.size(0)
    assert image_counter == image_grid_thw.size(0)
    assert len(image_batches) == len(result)

    return result


def batch_make_video_embeddings(
    video_batches: PromptVideoInput, processor, llm: VllmRunner
) -> list[Qwen2VLPromptVideoEmbeddingInput]:
    """Compute batched video embeddings for Qwen2-VL.

    One NDArray holds all frames of a single video.

    All videos are embedded in a single pass through the model's vision
    tower, and the result is split back according to the input batches.

    Args:
        video_batches: one entry per prompt. Each entry is either
          - a single video (`NDArray`), or
          - a list of videos (`list[NDArray]`) for multi-video prompts.
        processor: object exposing `video_processor`, whose `preprocess`
          result provides "pixel_values_videos" and "video_grid_thw" and
          which has a `merge_size` attribute.
        llm: runner whose `apply_model` executes a function on the model.

    Returns:
        One `Qwen2VLPromptVideoEmbeddingInput` per entry of `video_batches`,
        in the same order.

    Raises:
        AssertionError: if the embeddings or grids consumed while splitting
          do not add up to what the vision tower / processor produced.
    """
    # Normalise every entry to a list of videos (single video -> [video]).
    normalized_batches: list[Any] = [
        batch if isinstance(batch, list) else [batch] for batch in video_batches
    ]

    # Flatten all videos into one list so they are processed as one batch.
    videos: list[npt.NDArray] = [
        video for batch in normalized_batches for video in batch
    ]

    # video to pixel values
    video_processor = processor.video_processor
    preprocess_result = video_processor.preprocess(
        videos=videos, return_tensors="pt"
    ).data
    pixel_values = preprocess_result["pixel_values_videos"]
    video_grid_thw = preprocess_result["video_grid_thw"]

    # pixel values to embeddings & grid_thws
    def get_video_embeds(model):
        with torch.no_grad():
            visual = model.visual

            pixel_values_on_device = pixel_values.to(
                visual.device, dtype=visual.dtype
            )
            return visual(pixel_values_on_device, grid_thw=video_grid_thw).cpu()

    # `apply_model` returns a sequence of tensors (presumably one per worker);
    # they are concatenated along dim 0.
    video_embeds = torch.concat(llm.apply_model(get_video_embeds))

    # split into original batches
    merge_size = video_processor.merge_size
    result: list[Qwen2VLPromptVideoEmbeddingInput] = []
    video_counter = 0
    embed_counter = 0
    for batch in normalized_batches:
        video_end = video_counter + len(batch)
        batch_grid_thw = video_grid_thw[video_counter:video_end]

        # Each video contributes prod(grid_thw) // merge_size // merge_size
        # embedding rows; coerce to a plain int so the counters and slice
        # bounds are ordinary integers rather than 0-dim tensors.
        batch_embed_len = int(
            sum(
                grid_thw.prod(-1) // merge_size // merge_size
                for grid_thw in batch_grid_thw
            )
        )
        embed_end = embed_counter + batch_embed_len

        result.append(
            {
                "video_embeds": video_embeds[embed_counter:embed_end],
                "video_grid_thw": batch_grid_thw,
            }
        )

        embed_counter = embed_end
        video_counter = video_end

    # ensure we don't lose any videos or embeddings
    assert embed_counter == video_embeds.size(0)
    assert video_counter == video_grid_thw.size(0)
    assert len(video_batches) == len(result)

    return result


255
def run_embedding_input_test(
    vllm_runner: type[VllmRunner],
    inputs: list[tuple[list[str], PromptImageInput, PromptVideoInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: str | None = None,
):
    """Inference result should be the same between
    original image/video input and image/video embeddings input.

    Args:
        vllm_runner: runner class used as a context manager to host the model.
        inputs: test cases, each a `(prompts, images, videos)` tuple; an empty
          `images` / `videos` means that modality is not passed.
        model: model name, also used to load the `AutoProcessor`.
        dtype: dtype string forwarded to the runner.
        max_tokens: number of tokens to generate per prompt.
        num_logprobs: number of logprobs requested per token.
        mm_limit: per-prompt limit applied to both "image" and "video".
        tensor_parallel_size: forwarded to the runner.
        distributed_executor_backend: forwarded to the runner.
    """
    # Imported lazily so the module can be collected without transformers.
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model)

    # max_model_len should be greater than image_feature_size
    with vllm_runner(
        model,
        runner="generate",
        max_model_len=4000,
        max_num_seqs=3,
        dtype=dtype,
        limit_mm_per_prompt={"image": mm_limit, "video": mm_limit},
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        default_torch_num_threads=1,
        enable_mm_embeds=True,
    ) as vllm_model:
        # Reference run: raw images / videos are passed to the model.
        outputs_per_case_for_original_input = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images or None,
                videos=videos or None,
            )
            for prompts, images, videos in inputs
        ]

        # Same prompts, but with embeddings precomputed by the vision tower.
        outputs_per_case_for_embeddings_input = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=batch_make_image_embeddings(images, processor, vllm_model)
                if images
                else None,
                videos=batch_make_video_embeddings(videos, processor, vllm_model)
                if videos
                else None,
            )
            for prompts, images, videos in inputs
        ]

    # Compare the two runs case by case.
    for outputs_for_original_input, outputs_for_embeddings_input in zip(
        outputs_per_case_for_original_input, outputs_per_case_for_embeddings_input
    ):
        check_logprobs_close(
            outputs_0_lst=outputs_for_original_input,
            outputs_1_lst=outputs_for_embeddings_input,
            name_0="original_input",
            name_1="embeddings_input",
        )


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.5, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_image_embeddings_input(
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype,
    max_tokens,
    num_logprobs,
    monkeypatch,
) -> None:
    """Single-image prompts: raw image input vs. image embeddings input.

    One test case is built per (image asset, prompt) pair; within a case the
    same prompt is repeated once per size factor, each time paired with the
    image passed through `rescale_image_size` with that factor.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
        (
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
            [],  # no videos in this test
        )
        for image, prompt in zip(images, IMAGE_PROMPTS)
    ]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.5, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_multiple_image_embeddings_input(
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Multi-image prompts: raw image input vs. image embeddings input.

    A single test case is built; it repeats `MULTIIMAGE_PROMPT` once per size
    factor, each time paired with every image asset passed through
    `rescale_image_size` with that factor.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
        (
            [MULTIIMAGE_PROMPT for _ in size_factors],
            [
                [rescale_image_size(image, factor) for image in images]
                for factor in size_factors
            ],
            [],  # no videos in this test
        )
    ]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,  # two images per prompt
        tensor_parallel_size=1,
    )


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.25, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_video_embeddings_input(
    vllm_runner,
    video_assets,
    model,
    size_factors,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Single-video prompts: raw video input vs. video embeddings input.

    Each video asset is passed through `sample_frames_from_video` with
    `num_frames`. One test case is built per (video, prompt) pair; within a
    case the same prompt is repeated once per size factor, each time paired
    with the video passed through `rescale_video_size` with that factor.
    """
    num_frames = 4
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
        (
            [prompt for _ in size_factors],
            [],  # no images in this test
            [rescale_video_size(video, factor) for factor in size_factors],
        )
        for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)
    ]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )