test_qwen2_vl.py 14.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Any, Optional, TypedDict, Union
4
5
6
7
8
9

import numpy.typing as npt
import pytest
import torch
from PIL import Image

10
11
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
12
13
14
15
16

from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput,
                          PromptVideoInput, VllmRunner)
from ...utils import check_logprobs_close

17
18
19
20
21
22
23
24
25

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Run every test in this module with `VLLM_USE_V1` set to "0".

    V1 Test: the batch_make_xxxxx_embeddings helpers call a V0 internal,
    so the environment variable is pinned for each test function.
    """
    monkeypatch.setenv("VLLM_USE_V1", "0")


26
27
28
29
30
# Checkpoints every test in this module is parametrized over.
models = ["Qwen/Qwen2-VL-2B-Instruct"]
# dtype string the tests pass through to the vLLM runner.
target_dtype = "half"

# Token sequences spliced into a prompt, one per image / video input.
IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
# NOTE(review): presumably the hidden size of the checkpoint listed above;
# nothing visible in this file reads it — confirm before relying on it.
MODEL_HIDDEN_SIZE = 1536
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80


def qwen2_vl_chat_template(*query):
    """Wrap the given prompt fragments in the Qwen2-VL chat format.

    The fragments in *query* are joined without any separator and placed
    in the user turn, after a fixed system message. The returned string
    ends with the opening of the assistant turn.
    """
    user_message = "".join(query)
    return ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n{user_message}<|im_end|>"
            "<|im_start|>assistant\n")


# Prompts for the image tests: one image placeholder followed by a question.
# The keys are presumably asset names understood by IMAGE_ASSETS — verify
# against conftest. Fragments passed to the template are concatenated
# without a separator, hence the trailing spaces inside some literals.
IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    qwen2_vl_chat_template(
        IMAGE_PLACEHOLDER,
        "What is the biggest text's content in this image?",
    ),
    "cherry_blossom":
    qwen2_vl_chat_template(
        IMAGE_PLACEHOLDER,
        "What is the season shown in this image? ",
        "Reply with a short sentence (no more than 20 words)",
    ),
})

# Prompts for the video tests: one video placeholder followed by the request.
VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
    "sample_demo_1":
    qwen2_vl_chat_template(
        VIDEO_PLACEHOLDER,
        "Describe this video with a short sentence ",
        "(no more than 20 words)",
    ),
})

# Single prompt carrying two image placeholders, used by the multi-image test.
MULTIIMAGE_PROMPT = qwen2_vl_chat_template(
    IMAGE_PLACEHOLDER,
    IMAGE_PLACEHOLDER,
    "Describe these two images separately. ",
    "For each image, reply with a short sentence ",
    "(no more than 10 words).",
)


class Qwen2VLPromptImageEmbeddingInput(TypedDict):
    """Precomputed image input for one prompt, as built by
    `batch_make_image_embeddings`."""
    # Slice of the concatenated output of `model.visual(...)` belonging to
    # this prompt's images.
    image_embeds: torch.Tensor
    # Matching slice of the preprocessor's "image_grid_thw" output; indexed
    # along dim 0 per image.
    image_grid_thw: torch.Tensor


class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
    """Precomputed video input for one prompt, as built by
    `batch_make_video_embeddings`."""
    # Slice of the concatenated output of `model.visual(...)` belonging to
    # this prompt's videos.
    video_embeds: torch.Tensor
    # Matching slice of the preprocessor's "video_grid_thw" output; indexed
    # along dim 0 per video.
    video_grid_thw: torch.Tensor


def batch_make_image_embeddings(
        image_batches: list[Union[Image.Image, list[Image.Image]]], processor,
        llm: VllmRunner) -> list[Qwen2VLPromptImageEmbeddingInput]:
    """Compute batched image embeddings for Qwen2-VL.

    All images from every input batch are preprocessed and encoded in a
    single pass, then the result is split back according to the input
    batches.

    Args:
        image_batches: one entry per prompt, either
          - a single image: `Image.Image`, or
          - multiple images: `list[Image.Image]`.
        processor: object whose `image_processor` provides `preprocess`
          and `merge_size`.
        llm: runner whose `apply_model` runs a callable against the loaded
          model.

    Returns:
        `list[Qwen2VLPromptImageEmbeddingInput]` with one entry per element
        of `image_batches`, in the same order.
    """
    # convert single-image batches to multiple-image batches; this builds a
    # new outer list, so the caller's list is left untouched
    normalized_batches: list[list[Image.Image]] = [
        batch if isinstance(batch, list) else [batch]
        for batch in image_batches
    ]

    # append all images into a list (as a batch)
    images: list[Image.Image] = [
        image for batch in normalized_batches for image in batch
    ]

    # image to pixel values
    image_processor = processor.image_processor
    preprocess_result = image_processor.preprocess(
        images=images, return_tensors="pt").data
    pixel_values = preprocess_result["pixel_values"]
    image_grid_thw = preprocess_result["image_grid_thw"]

    # pixel values to embeddings & grid_thws
    def get_image_embeds(model):
        with torch.no_grad():
            visual = model.visual

            # inputs must live on the same device as `model.visual`
            pixel_values_on_device = pixel_values.to(visual.device,
                                                     dtype=visual.dtype)
            image_grid_thw_on_device = image_grid_thw.to(visual.device,
                                                         dtype=torch.int64)
            return visual(pixel_values_on_device,
                          grid_thw=image_grid_thw_on_device)

    # V1 Test: this calls a V0 internal.
    # NOTE(review): the results returned by `apply_model` are concatenated
    # as-is; confirm this is still correct when more than one worker
    # returns a result.
    image_embeds = torch.concat(llm.apply_model(get_image_embeds))

    # split into original batches
    result: list[Qwen2VLPromptImageEmbeddingInput] = []
    merge_size = image_processor.merge_size  # loop-invariant
    image_counter = 0
    embed_counter = 0
    for image_batch in normalized_batches:
        image_end = image_counter + len(image_batch)
        batch_grid_thw = image_grid_thw[image_counter:image_end]

        # each image contributes prod(t, h, w) // merge_size**2 embedding
        # rows; convert to plain ints so the counters never become tensors
        cur_batch_embed_len = sum(
            int(grid_thw.prod(-1)) // merge_size // merge_size
            for grid_thw in batch_grid_thw)
        embed_end = embed_counter + cur_batch_embed_len

        result.append({
            "image_embeds": image_embeds[embed_counter:embed_end],
            "image_grid_thw": batch_grid_thw,
        })

        image_counter = image_end
        embed_counter = embed_end

    # ensure we don't lose any images or embeddings
    assert embed_counter == image_embeds.size(0)
    assert image_counter == image_grid_thw.size(0)
    assert len(image_batches) == len(result)

    return result


def batch_make_video_embeddings(
        video_batches: PromptVideoInput, processor,
        llm: VllmRunner) -> list[Qwen2VLPromptVideoEmbeddingInput]:
    """Compute batched video embeddings for Qwen2-VL.

    An NDArray represents all frames of a single video.

    All videos from every input batch are preprocessed and encoded in a
    single pass, then the result is split back according to the input
    batches.

    Args:
        video_batches: one entry per prompt, either
          - a single video: `NDArray`, or
          - multiple videos: `list[NDArray]`.
        processor: object whose `image_processor` provides `preprocess`
          and `merge_size`.
        llm: runner whose `apply_model` runs a callable against the loaded
          model.

    Returns:
        `list[Qwen2VLPromptVideoEmbeddingInput]` with one entry per element
        of `video_batches`, in the same order.
    """
    # convert single-video batches to multiple-video batches; this builds a
    # new outer list, so the caller's list is left untouched
    normalized_batches: list[list[npt.NDArray]] = [
        batch if isinstance(batch, list) else [batch]
        for batch in video_batches
    ]

    # append all videos into a list (as a batch)
    videos: list[npt.NDArray] = [
        video for batch in normalized_batches for video in batch
    ]

    # video to pixel values
    image_processor = processor.image_processor
    preprocess_result = image_processor.preprocess(
        images=None, videos=videos, return_tensors="pt").data
    pixel_values = preprocess_result["pixel_values_videos"]
    video_grid_thw = preprocess_result["video_grid_thw"]

    # pixel values to embeddings & grid_thws
    def get_video_embeds(model):
        with torch.no_grad():
            visual = model.visual

            # inputs must live on the same device as `model.visual`
            pixel_values_on_device = pixel_values.to(visual.device,
                                                     dtype=visual.dtype)
            video_grid_thw_on_device = video_grid_thw.to(visual.device,
                                                         dtype=torch.int64)
            return visual(pixel_values_on_device,
                          grid_thw=video_grid_thw_on_device)

    # V1 Test: this calls a V0 internal.
    # NOTE(review): the results returned by `apply_model` are concatenated
    # as-is; confirm this is still correct when more than one worker
    # returns a result.
    video_embeds = torch.concat(llm.apply_model(get_video_embeds))

    # split into original batches
    result: list[Qwen2VLPromptVideoEmbeddingInput] = []
    merge_size = image_processor.merge_size  # loop-invariant
    video_counter = 0
    embed_counter = 0
    for video_batch in normalized_batches:
        video_end = video_counter + len(video_batch)
        batch_grid_thw = video_grid_thw[video_counter:video_end]

        # each video contributes prod(t, h, w) // merge_size**2 embedding
        # rows; convert to plain ints so the counters never become tensors
        cur_batch_embed_len = sum(
            int(grid_thw.prod(-1)) // merge_size // merge_size
            for grid_thw in batch_grid_thw)
        embed_end = embed_counter + cur_batch_embed_len

        result.append({
            "video_embeds": video_embeds[embed_counter:embed_end],
            "video_grid_thw": batch_grid_thw,
        })

        video_counter = video_end
        embed_counter = embed_end

    # ensure we don't lose any videos or embeddings
    assert embed_counter == video_embeds.size(0)
    assert video_counter == video_grid_thw.size(0)
    assert len(video_batches) == len(result)

    return result


248
def run_embedding_input_test(
    vllm_runner: type[VllmRunner],
    inputs: list[tuple[list[str], PromptImageInput, PromptVideoInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between
    original image/video input and image/video embeddings input.

    Each element of `inputs` is one case: `(prompts, images, videos)`.
    Every case is generated twice with the same runner, once with the raw
    images/videos and once with embeddings precomputed by
    `batch_make_image_embeddings` / `batch_make_video_embeddings`. The
    logprobs of the two runs are then compared case by case.

    An empty `images` or `videos` entry is passed to the runner as `None`.
    """
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model)

    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            task="generate",
            max_model_len=4000,
            max_num_seqs=3,
            dtype=dtype,
            limit_mm_per_prompt={
                "image": mm_limit,
                "video": mm_limit
            },
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
    ) as vllm_model:

        # first pass: every case with the raw multimodal inputs
        outputs_per_case_for_original_input = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images or None,
                                                videos=videos or None)
            for prompts, images, videos in inputs
        ]

        # second pass: the same cases with precomputed embeddings
        outputs_per_case_for_embeddings_input = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=batch_make_image_embeddings(
                    images, processor, vllm_model) if images else None,
                videos=batch_make_video_embeddings(
                    videos, processor, vllm_model) if videos else None)
            for prompts, images, videos in inputs
        ]

    for outputs_for_original_input, outputs_for_embeddings_input in zip(
            outputs_per_case_for_original_input,
            outputs_per_case_for_embeddings_input):
        check_logprobs_close(
            outputs_0_lst=outputs_for_original_input,
            outputs_1_lst=outputs_for_embeddings_input,
            name_0="original_input",
            name_1="embeddings_input",
        )


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.5, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
                                         size_factors, dtype: str,
                                         max_tokens: int,
                                         num_logprobs: int) -> None:
    """Single-image prompts: raw images vs. precomputed image embeddings.

    One case is built per (asset image, prompt) pair. Within a case the
    prompt is repeated once per entry of `size_factors`, each paired with
    the image rescaled by that factor. No videos are passed.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case: list[tuple[list[str], PromptImageInput,
                                PromptVideoInput]] = [
        (
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
            [],
        ) for image, prompt in zip(images, IMAGE_PROMPTS)
    ]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        [],
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.5, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
                                                  model, size_factors,
                                                  dtype: str, max_tokens: int,
                                                  num_logprobs: int) -> None:
    """Multi-image prompts: raw images vs. precomputed image embeddings.

    A single case is built. For each entry of `size_factors` it holds one
    copy of `MULTIIMAGE_PROMPT` together with all asset images rescaled by
    that factor. With the empty `size_factors` parametrization the case
    has no prompts and no images. No videos are passed.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_case: list[tuple[list[str], PromptImageInput,
                                PromptVideoInput]] = [(
        [MULTIIMAGE_PROMPT for _ in size_factors],
        [[rescale_image_size(image, factor) for image in images]
         for factor in size_factors],
        [],
    )]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [0.5],
        # Single-scale, batched
        [0.5, 0.5],
        # Multi-scale
        [0.25, 0.25, 0.5],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
                                         size_factors, dtype: str,
                                         max_tokens: int,
                                         num_logprobs: int) -> None:
    """Video prompts: raw videos vs. precomputed video embeddings.

    `num_frames` frames are sampled from each asset video. One case is
    built per (sampled video, prompt) pair. Within a case the prompt is
    repeated once per entry of `size_factors`, each paired with the video
    rescaled by that factor. No images are passed.
    """
    num_frames = 4
    sampled_vids = [
        sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]

    inputs_per_case: list[tuple[list[str], PromptImageInput,
                                PromptVideoInput]] = [
        (
            [prompt for _ in size_factors],
            [],
            [rescale_video_size(video, factor) for factor in size_factors],
        ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)
    ]

    run_embedding_input_test(
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )