# internvl.py (18.2 KB)
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import numpy.typing as npt
import torch
import torchvision.transforms as T
from PIL import Image
15
16
from transformers import BatchFeature, TensorType
from transformers.processing_utils import ProcessorMixin
17
18
19

from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails
20
from vllm.tokenizers.hf import HfTokenizer
21
22
23
24
25
26
27
28

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the per-tile image transform used by InternVL.

    The pipeline converts the image to RGB mode, resizes it to a square of
    side ``input_size`` using bicubic interpolation, converts it to a tensor
    and normalizes it with the ImageNet mean/std constants defined above.

    Args:
        input_size: Side length, in pixels, of the square output tile.

    Returns:
        A ``torchvision.transforms.Compose`` callable taking a PIL image.
    """
    return T.Compose(
        [
            # Normalize the colour mode first so every input yields RGB.
            T.Lambda(lambda img: convert_image_mode(img, "RGB")),
            T.Resize(
                (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC
            ),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tile grid whose aspect ratio best matches ``aspect_ratio``.

    Candidates are scanned in the given order. A strictly smaller
    aspect-ratio error always replaces the current choice. On an exact tie,
    the later candidate wins only when the source image area exceeds half of
    the area that grid would cover with ``image_size``-pixel tiles.

    Args:
        aspect_ratio: Source width divided by source height.
        target_ratios: Candidate ``(columns, rows)`` grids.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile.

    Returns:
        The selected ``(columns, rows)`` grid, or ``(1, 1)`` when
        ``target_ratios`` is empty.
    """
    pixel_area = width * height
    chosen = (1, 1)
    smallest_gap = float("inf")

    for cols, rows in target_ratios:
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, (cols, rows)
        elif gap == smallest_gap and (
            pixel_area > 0.5 * image_size * image_size * cols * rows
        ):
            # Tie: prefer the larger grid only if the image is big enough.
            chosen = (cols, rows)

    return chosen


def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Return the effective ``(min_num, max_num)`` tile-count bounds.

    When dynamic sizing is disabled both bounds collapse to 1. Otherwise the
    configured bounds are used, and the upper bound is widened by one to make
    room for the thumbnail tile when thumbnails are on and more than one tile
    is allowed.
    """
    if not dynamic_image_size:
        # Single fixed tile; a thumbnail is never added when max is 1.
        return 1, 1

    upper = max_dynamic_patch
    if use_thumbnail and upper != 1:
        upper = upper + 1
    return min_dynamic_patch, upper


def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every ``(columns, rows)`` grid whose tile count lies in range.

    A grid qualifies when ``min_num <= columns * rows <= max_num``. Duplicates
    produced by the outer ``n`` loop collapse in a set. The result is ordered
    by tile count; grids with equal tile counts keep the set's iteration
    order.
    """
    candidates: set[tuple[int, int]] = set()
    # Same visiting order as the reference implementation, so the set (and
    # hence the tie order after sorting) is built identically.
    for n in range(min_num, max_num + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if min_num <= cols * rows <= max_num:
                    candidates.add((cols, rows))

    def _tile_count(grid: tuple[int, int]) -> int:
        return grid[0] * grid[1]

    return sorted(candidates, key=_tile_count)


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile layout for an image of the given size.

    Returns:
        ``(num_blocks, target_width, target_height)``. ``target_width`` and
        ``target_height`` are the size the image is resized to before being cut
        into ``image_size`` squares. ``num_blocks`` is the grid's tile count,
        plus one for the thumbnail when ``use_thumbnail`` is set and the grid
        has more than one tile.
    """
    cols, rows = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    num_blocks = cols * rows
    if use_thumbnail and num_blocks != 1:
        # Room for the extra downscaled whole-image tile.
        num_blocks += 1

    return num_blocks, image_size * cols, image_size * rows


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Cut ``image`` into square ``image_size`` tiles.

    The image is resized to the grid chosen by ``calculate_internvl_targets``
    and cropped row by row, left to right. If ``use_thumbnail`` is set and
    more than one tile was produced, a resized copy of the whole image is
    appended as a final tile.
    """
    width, height = image.size

    # Thumbnail is handled below, so it is excluded from the tile count here.
    num_tiles, grid_width, grid_height = calculate_internvl_targets(
        orig_width=width,
        orig_height=height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized = image.resize((grid_width, grid_height))
    tiles_per_row = grid_width // image_size

    tiles: list[Image.Image] = []
    for idx in range(num_tiles):
        row, col = divmod(idx, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(resized.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == num_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Tile one image and stack the normalized tiles along a new first dim.

    Tiles come from ``dynamic_preprocess_internvl`` using every grid with
    ``min_num``..``max_num`` tiles. Each tile is passed through
    ``build_transform(input_size)``.
    """
    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=get_internvl_target_ratios(min_num, max_num),
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )
    to_tensor = build_transform(input_size=input_size)
    return torch.stack(list(map(to_tensor, tiles)))


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def video_to_pixel_values_internvl(
    video: npt.NDArray,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Turn each frame of ``video`` into one normalized tile and stack them.

    Frames are iterated along the first axis of ``video`` and converted with
    ``Image.fromarray(..., mode="RGB")``. Each frame must yield exactly one
    tile; this is asserted.
    """
    ratios = get_internvl_target_ratios(min_num, max_num)
    to_tensor = build_transform(input_size=input_size)

    frame_tiles: list[Image.Image] = []
    for raw_frame in video:
        tiles = dynamic_preprocess_internvl(
            Image.fromarray(raw_frame, mode="RGB"),
            target_ratios=ratios,
            image_size=input_size,
            use_thumbnail=use_thumbnail,
        )
        assert len(tiles) == 1
        frame_tiles += tiles

    return torch.stack([to_tensor(tile) for tile in frame_tiles])


218
class InternVLImageProcessor:
    """Image preprocessor for InternVL: dynamic tiling plus normalization.

    Each image is split into ``image_size`` x ``image_size`` tiles whose grid
    best matches the image's aspect ratio, optionally followed by a thumbnail
    tile. Every tile is then normalized by ``build_transform``.
    """

    def __init__(
        self,
        image_size: int,
        min_dynamic_patch: int,
        max_dynamic_patch: int,
        dynamic_image_size: bool,
        use_thumbnail: bool,
    ) -> None:
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> tuple[int, int]:
        """Return ``(min_num, max_num)`` tile bounds.

        Any argument left as ``None`` falls back to the value configured on
        this processor.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Tile and normalize each image; one stacked tensor per image."""
        # The thumbnail tile is appended by image_to_pixel_values_internvl, so
        # it must not also widen the tile budget here.
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            )
            for image in images
        ]

    def __call__(
        self,
        images: Image.Image | list[Image.Image],
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Returns a ``BatchFeature`` with:
          - ``pixel_values_flat``: all images' tile stacks concatenated along
            the first dimension;
          - ``image_num_patches``: the number of tiles contributed by each
            image, in input order.
        Extra ``kwargs`` are accepted and ignored.
        """
        images_lst = images if isinstance(images, list) else [images]

        pixel_values_lst = self._images_to_pixel_values_lst(
            images_lst,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        image_inputs = {
            "pixel_values_flat": torch.cat(pixel_values_lst),
            "image_num_patches": torch.tensor([len(item) for item in pixel_values_lst]),
        }
        return BatchFeature(image_inputs, tensor_type=return_tensors)
313
314


315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
class InternVLVideoProcessor:
    """Video preprocessor for InternVL.

    Every frame is mapped to exactly one ``image_size`` x ``image_size`` tile:
    no dynamic tiling and no thumbnail. Frames are normalized the same way as
    image tiles.
    """

    def __init__(
        self,
        image_size: int,
    ) -> None:
        self.image_size = image_size

    def _videos_to_pixel_values_lst(
        self,
        videos: list[npt.NDArray],
    ) -> list[torch.Tensor]:
        """Convert each video into a stacked tensor with one tile per frame."""
        # min_num == max_num == 1 restricts the grid to a single tile.
        return [
            video_to_pixel_values_internvl(
                video,
                input_size=self.image_size,
                min_num=1,
                max_num=1,
                use_thumbnail=False,
            )
            for video in videos
        ]

    def __call__(
        self,
        videos: npt.NDArray | list[npt.NDArray],
        *,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one video or a list of videos.

        Returns a ``BatchFeature`` with:
          - ``pixel_values_flat_video``: all videos' frame tiles concatenated
            along the first dimension;
          - ``video_num_patches``: the number of tiles (one per frame) for
            each video, in input order.
        Extra ``kwargs`` are accepted and ignored.
        """
        videos_lst = videos if isinstance(videos, list) else [videos]

        pixel_values_lst = self._videos_to_pixel_values_lst(videos_lst)

        video_inputs = {
            "pixel_values_flat_video": torch.cat(pixel_values_lst),
            "video_num_patches": torch.tensor([len(item) for item in pixel_values_lst]),
        }
        return BatchFeature(video_inputs, tensor_type=return_tensors)
353

354
355

class InternVLProcessor(ProcessorMixin):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252

    Code for video processing is adapted from video example:
    https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
    """

    attributes = ["image_processor", "tokenizer", "video_processor"]

    def __init__(
        self,
        image_processor: InternVLImageProcessor,
        tokenizer: HfTokenizer,
        video_processor: InternVLVideoProcessor | None = None,
        *,
        image_seq_length: int,
        start_image_token: str = "<img>",
        end_image_token: str = "</img>",
        ctx_image_token: str = "<IMG_CONTEXT>",
        ctx_video_token: str | None = None,
    ) -> None:
        # NOTE(review): ProcessorMixin.__init__ is not called; attributes are
        # assigned directly. Confirm this is intended.
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.video_processor = video_processor

        # Number of context tokens emitted per tile.
        self.image_seq_length = image_seq_length
        self.start_image_token = start_image_token
        self.end_image_token = end_image_token
        self.ctx_image_token = ctx_image_token
        self.ctx_video_token = ctx_video_token

        self.start_image_token_id = tokenizer.convert_tokens_to_ids(start_image_token)
        self.end_image_token_id = tokenizer.convert_tokens_to_ids(end_image_token)
        self.ctx_image_token_id = tokenizer.convert_tokens_to_ids(ctx_image_token)
        self.ctx_video_token_id = (
            None
            if ctx_video_token is None
            else tokenizer.convert_tokens_to_ids(ctx_video_token)
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate tile grids for the resolved tile-count bounds."""
        min_num, max_num = self.image_processor.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of context tokens an image of the given size expands to."""
        image_processor = self.image_processor
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=image_processor.image_size,
            target_ratios=target_ratios,
            use_thumbnail=image_processor.use_thumbnail,
        )

        return num_patches * self.image_seq_length

    def get_image_repl(
        self,
        num_patches: int | None,
        num_features: int | None = None,
    ) -> PromptUpdateDetails[str]:
        """Build the prompt text that replaces one image placeholder.

        Either ``num_patches`` (tiles) or, if that is ``None``, an explicit
        ``num_features`` (context-token count) must be given. The result is
        ``start + ctx * num_features + end``, with the context token selected
        as the embedding target.
        """
        if num_patches is None:
            assert num_features is not None
        else:
            num_features = num_patches * self.image_seq_length

        repl_features = self.ctx_image_token * num_features
        repl_full = self.start_image_token + repl_features + self.end_image_token

        return PromptUpdateDetails.select_text(repl_full, self.ctx_image_token)

    def get_video_repl(self, num_patches: int) -> PromptUpdateDetails[str]:
        """Build the prompt text that replaces one video placeholder.

        Emits one ``FrameN: <start><ctx...><end>`` segment per frame.
        """
        assert self.ctx_video_token is not None

        repl_features = self.ctx_video_token * self.image_seq_length
        repl_features_with_sep = (
            self.start_image_token + repl_features + self.end_image_token
        )
        # num_patches is equal to num_frames
        repl_full = "".join(
            [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
        )

        return PromptUpdateDetails.select_text(repl_full, self.ctx_video_token)

    @staticmethod
    def _expand_tokens(text: list[str], token: str, num_patches, get_repl) -> list[str]:
        """Expand every ``token`` occurrence across ``text`` into its replacement.

        Occurrences are numbered globally across all prompts, left to right.
        The k-th occurrence is replaced by ``get_repl(int(num_patches[k])).full``.
        Splitting on the token avoids an in-band sentinel string, which could
        collide with literal user text. Indexing past the end of
        ``num_patches`` raises ``IndexError``, as before.
        """
        index = 0
        expanded = list[str]()
        for prompt in text:
            pieces = prompt.split(token)
            out = [pieces[0]]
            for piece in pieces[1:]:
                # int() so tensor / ndarray elements repeat strings reliably.
                out.append(get_repl(int(num_patches[index])).full)
                out.append(piece)
                index += 1
            expanded.append("".join(out))
        return expanded

    def __call__(
        self,
        text: str | list[str] | None = None,
        images: Image.Image | list[Image.Image] | None = None,
        videos: npt.NDArray | list[npt.NDArray] | None = None,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess images/videos and expand their placeholders in ``text``.

        ``<image>`` and ``<video>`` markers in the prompts are replaced, in
        order, by the token sequences from ``get_image_repl`` /
        ``get_video_repl``. The expanded text is then tokenized. Outputs from
        the tokenizer, image processor and video processor are merged into one
        ``BatchFeature``.

        Raises:
            ValueError: if videos are passed but no video processor is set.
        """
        if images is not None:
            image_inputs = self.image_processor(
                images=images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
                return_tensors=return_tensors,
            )
            image_num_patches = image_inputs["image_num_patches"]
        else:
            image_inputs = {}
            image_num_patches = []

        if videos is not None:
            if self.video_processor is None:
                raise ValueError("This model does not support video inputs")

            video_inputs = self.video_processor(
                videos=videos,
                return_tensors=return_tensors,
            )
            video_num_patches = video_inputs["video_num_patches"]
        else:
            video_inputs = {}
            video_num_patches = []

        if text is not None:
            if not isinstance(text, list):
                text = [text]

            if image_inputs:
                text = self._expand_tokens(
                    text, "<image>", image_num_patches, self.get_image_repl
                )

            if video_inputs:
                text = self._expand_tokens(
                    text, "<video>", video_num_patches, self.get_video_repl
                )

            text_inputs = self.tokenizer(text, return_tensors=return_tensors)
        else:
            text_inputs = {}

        combined_outputs = {**text_inputs, **image_inputs, **video_inputs}

        return BatchFeature(combined_outputs, tensor_type=return_tensors)