# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence

import torch
from PIL import Image
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing.processor import (
    MultiModalProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.tokenizers import TokenizerLike

from .intern_vit import InternVisionModel
from .internvl import (
    IMG_CONTEXT,
    IMG_END,
    IMG_START,
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    BaseInternVLProcessor,
    InternVLChatModel,
    build_transform,
    find_closest_aspect_ratio,
    get_internvl_target_ratios,
)
47

48
49
50
51
52
53

def resolve_h2ovl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective minimum and maximum number of image patches.

    Args:
        min_dynamic_patch: Requested lower bound on the patch count.
        max_dynamic_patch: Requested upper bound on the patch count.
        dynamic_image_size: When false, both bounds collapse to ``1``.
        use_thumbnail: When true and the upper bound is not ``1``, the upper
            bound is increased by one to leave room for the thumbnail.

    Returns:
        ``(min_num, max_num)`` after applying the rules above.
    """
    if not dynamic_image_size:
        # Dynamic tiling is disabled: exactly one patch is used.
        min_dynamic_patch = 1
        max_dynamic_patch = 1

    # The thumbnail only counts when the image may be split into several
    # patches; a single-patch image never gets an extra thumbnail slot.
    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch


def get_h2ovl_target_ratios(
    min_num: int,
    max_num: int,
    *,
    prior_aspect_ratio: tuple[int, int] | None,
) -> list[tuple[int, int]]:
    """Return the candidate tiling ratios, optionally filtered by a prior one.

    Args:
        min_num: Lower bound forwarded to ``get_internvl_target_ratios``.
        max_num: Upper bound forwarded to ``get_internvl_target_ratios``.
        prior_aspect_ratio: Ratio selected by an earlier pass, or ``None`` to
            return the candidates unfiltered.

    Returns:
        The candidate ratios. When ``prior_aspect_ratio`` is given, a
        candidate is kept only if neither of its components evenly divides
        the matching component of the prior ratio.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    # if prior_aspect_ratio is provided, filter the target ratios
    if prior_aspect_ratio is not None:
        target_ratios = [
            ratio
            for ratio in target_ratios
            if prior_aspect_ratio[0] % ratio[0] != 0
            and prior_aspect_ratio[1] % ratio[1] != 0
        ]

    return target_ratios


# modified to include blocks generated in second pass
def calculate_h2ovl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int, tuple[int, int]]:
    """Pick a tiling ratio for an image and derive the resize target from it.

    Args:
        orig_width: Width of the source image.
        orig_height: Height of the source image.
        target_ratios: Candidate ``(columns, rows)`` ratios to choose from.
        image_size: Side length of one square tile.
        use_thumbnail: When true and more than one tile results, one extra
            block is counted for the thumbnail.

    Returns:
        ``(blocks, target_width, target_height, target_aspect_ratio)`` where
        ``target_width``/``target_height`` are ``image_size`` times the
        chosen ratio's first/second component and ``blocks`` is the tile
        count (plus one when the thumbnail is counted).
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height, target_aspect_ratio


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
118
119
# refactored to handle prior_aspect_ratio
def dynamic_preprocess_h2ovl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[list[Image.Image], tuple[int, int]]:
    """Resize an image to the best-matching grid and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate ``(columns, rows)`` ratios to choose from.
        image_size: Side length of each square tile.
        use_thumbnail: When true and more than one tile was produced, a
            copy of the whole image resized to ``image_size`` x
            ``image_size`` is appended after the tiles.

    Returns:
        ``(tiles, target_aspect_ratio)``. Tiles are ordered row by row, left
        to right, with the optional thumbnail last.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    (
        blocks,
        target_width,
        target_height,
        target_aspect_ratio,
    ) = calculate_h2ovl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # Number of tiles in one row of the resized image (loop invariant).
    tiles_per_row = target_width // image_size

    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        left = col * image_size
        upper = row * image_size
        # split the image
        box = (left, upper, left + image_size, upper + image_size)
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    # The thumbnail is built from the original image, not the resized one.
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images, target_aspect_ratio


165
def _preprocess_image(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    prior_aspect_ratio: tuple[int, int] | None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Tile one image and stack the transformed tiles into a tensor.

    Args:
        image: Source image.
        input_size: Tile side length, also passed to ``build_transform``.
        min_num: Lower bound used when building the candidate ratios.
        max_num: Upper bound used when building the candidate ratios.
        use_thumbnail: Forwarded to ``dynamic_preprocess_h2ovl``.
        prior_aspect_ratio: Ratio from an earlier pass used to filter the
            candidates, or ``None`` for no filtering.

    Returns:
        ``(pixel_values, target_aspect_ratio)`` where ``pixel_values`` is
        the ``torch.stack`` of the transformed tiles, in tile order.
    """
    target_ratios = get_h2ovl_target_ratios(
        min_num,
        max_num,
        prior_aspect_ratio=prior_aspect_ratio,
    )

    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess_h2ovl(
        image,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
        target_ratios=target_ratios,
    )

    pixel_values = torch.stack([transform(image) for image in images])
    return pixel_values, target_aspect_ratio


192
193
# refactored to use the _preprocess_image function
def image_to_pixel_values_h2ovl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_msac: bool,
) -> torch.Tensor:
    """Convert one image into a stacked tensor of tile pixel values.

    Args:
        image: Source image.
        input_size: Tile side length.
        min_num: Lower bound on the tile count (non-MSAC path only).
        max_num: Upper bound on the tile count.
        use_thumbnail: Whether to append a thumbnail (non-MSAC path only).
        use_msac: When true, the image is processed in two passes and the
            results are merged; ``min_num`` and ``use_thumbnail`` are not
            consulted on this path (the passes hard-code ``min_num`` to 1
            and 3 and always request a thumbnail).

    Returns:
        The tile tensor. With MSAC the order is: second-pass entries except
        the last, first-pass entries except the last, then the last
        second-pass entry.
    """
    # when MSAC is turned on, we need to process the image twice
    if use_msac:
        # first pass
        pixel_values1, aspect_ratio1 = _preprocess_image(
            image,
            input_size=input_size,
            min_num=1,
            max_num=max_num,
            use_thumbnail=True,
            prior_aspect_ratio=None,
        )
        # second pass: candidates are filtered by the first pass's ratio
        pixel_values2, _ = _preprocess_image(
            image,
            input_size=input_size,
            min_num=3,
            max_num=max_num,
            use_thumbnail=True,
            prior_aspect_ratio=aspect_ratio1,
        )
        # combine pixel values; only the final entry of the second pass is
        # kept at the end, the final entry of the first pass is dropped
        pixel_values = torch.cat(
            [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0
        )

    else:
        pixel_values, _ = _preprocess_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
            prior_aspect_ratio=None,
        )

    return pixel_values


240
241
242
243
class H2OVLProcessor(BaseInternVLProcessor):
    """InternVL-style processor with H2OVL's optional two-pass MSAC tiling.

    Attributes such as ``min_dynamic_patch``, ``max_dynamic_patch``,
    ``dynamic_image_size``, ``use_thumbnail``, ``image_size``,
    ``num_image_token`` and ``tokenizer`` are read here but not assigned in
    this class; presumably they are set by ``BaseInternVLProcessor`` —
    verify against the base class.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_msac: bool | None = None,
    ) -> None:
        """Initialise the base processor and resolve the MSAC switch.

        Args:
            config: Model configuration; ``config.use_msac`` supplies the
                default when ``use_msac`` is ``None``.
            tokenizer: Tokenizer forwarded to the base class.
            min_dynamic_patch: Optional override forwarded to the base class.
            max_dynamic_patch: Optional override forwarded to the base class.
            dynamic_image_size: Optional override forwarded to the base class.
            use_msac: Explicit MSAC setting, or ``None`` to use the config.
        """
        super().__init__(
            config,
            tokenizer,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        if use_msac is None:
            use_msac = config.use_msac
        assert isinstance(use_msac, bool)

        self.use_msac = use_msac

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_CONTEXT`` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """Build the prompt text that replaces one image placeholder.

        Args:
            feature_size: Number of ``IMG_CONTEXT`` tokens to emit.
            num_patches: Accepted for interface compatibility; not used here.

        Returns:
            ``IMG_START`` + repeated ``IMG_CONTEXT`` + ``IMG_END``, with
            ``IMG_CONTEXT`` passed to ``select_text`` as the selected text.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> tuple[int, int]:
        """Resolve the patch-count bounds, defaulting to instance settings.

        Every argument left as ``None`` falls back to the attribute of the
        same name on this processor before delegating to
        ``resolve_h2ovl_min_max_num``.
        """
        min_dynamic_patch = (
            self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
        )
        max_dynamic_patch = (
            self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch
        )
        dynamic_image_size = (
            self.dynamic_image_size
            if dynamic_image_size is None
            else dynamic_image_size
        )
        use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail

        return resolve_h2ovl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
        prior_aspect_ratio: tuple[int, int] | None = None,
        override_min_num: int | None = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate tiling ratios for the resolved bounds.

        Args:
            min_dynamic_patch: Optional override, see ``resolve_min_max_num``.
            max_dynamic_patch: Optional override, see ``resolve_min_max_num``.
            dynamic_image_size: Optional override, see ``resolve_min_max_num``.
            use_thumbnail: Optional override, see ``resolve_min_max_num``.
            prior_aspect_ratio: Ratio from an earlier pass used to filter
                the candidates, or ``None`` for no filtering.
            override_min_num: When given, replaces the resolved minimum.
        """
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        if override_min_num is not None:
            min_num = override_min_num

        return get_h2ovl_target_ratios(
            min_num,
            max_num,
            prior_aspect_ratio=prior_aspect_ratio,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        use_msac: bool | None = None,
    ) -> int:
        """Compute how many image tokens an image of the given size yields.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            use_msac: Explicit MSAC setting, or ``None`` to use
                ``self.use_msac``.

        Returns:
            The patch count multiplied by ``self.num_image_token``.
        """
        use_msac = self.use_msac if use_msac is None else use_msac

        use_thumbnail = self.use_thumbnail

        if use_msac:
            target_ratios_1 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                override_min_num=1,
            )
            num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios_1,
                use_thumbnail=True,
            )

            target_ratios_2 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                prior_aspect_ratio=aspect_ratio_1,
                override_min_num=3,
            )
            num_patches_2, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios_2,
                use_thumbnail=True,
            )

            # One entry is subtracted, matching image_to_pixel_values_h2ovl,
            # which drops the last first-pass entry when merging the passes.
            num_patches = num_patches_1 + num_patches_2 - 1
        else:
            target_ratios = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
            )
            num_patches, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios,
                use_thumbnail=use_thumbnail,
            )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Convert each image to its tile tensor.

        MSAC is applied only when exactly one image is passed; with any
        other number of images it is disabled regardless of
        ``self.use_msac``.
        """
        use_msac = self.use_msac if len(images) == 1 else False

        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_h2ovl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
                use_msac=use_msac,
            )
            for image in images
        ]


class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds and queries an ``H2OVLProcessor``."""

    def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor:
        """Create an ``H2OVLProcessor`` via ``self.ctx.init_processor``.

        The HF config and tokenizer of this info object are supplied
        automatically; ``kwargs`` are forwarded unchanged.
        """
        return self.ctx.init_processor(
            H2OVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: H2OVLProcessor,
        use_msac: bool | None = None,
    ) -> int:
        """Return the image-token count computed by ``processor``.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            processor: Processor that performs the calculation.
            use_msac: Forwarded as-is; ``None`` lets the processor use its
                own ``use_msac`` setting.
        """
        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
            use_msac=use_msac,
        )
435
436


437
class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
    """Multi-modal processor for H2OVL prompts containing ``<image>``."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build the replacement rule for ``<image>`` placeholders.

        Args:
            mm_items: Parsed multi-modal inputs; image items are looked up
                per placeholder.
            hf_processor_mm_kwargs: Keyword arguments used to construct the
                HF processor.
            out_mm_kwargs: Processor outputs; ``image_num_patches`` or
                ``image_embeds`` determines the per-image patch counts.

        Returns:
            A single ``PromptReplacement`` targeting ``"<image>"``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        out_mm_data = out_mm_kwargs.get_data()
        if "image_num_patches" in out_mm_data:
            image_num_patches = out_mm_data["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_data["image_embeds"])
        else:
            image_num_patches = []

        num_images = len(image_num_patches)

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    # With exactly one image the processor's own MSAC
                    # setting applies; otherwise MSAC is forced off, matching
                    # H2OVLProcessor._images_to_pixel_values_lst.
                    use_msac=None if num_images == 1 else False,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """Apply the HF processor, bypassing the cache for multi-image prompts.

        Prompts with more than one image go straight to
        ``_apply_hf_processor``; all others use the base class's cached path.
        """
        # The processor logic is different for len(images) <= 1 vs > 1
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 1:
            return self._apply_hf_processor(
                prompt=prompt,
                mm_data_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

519

520
521
522
@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,
    info=H2OVLProcessingInfo,
    dummy_inputs=BaseInternVLDummyInputsBuilder,
)
class H2OVLChatModel(InternVLChatModel):
    """InternVL chat model variant registered with the H2OVL processor."""

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the ``InternVisionModel`` used as the vision encoder.

        Args:
            config: Model configuration providing ``select_layer`` and
                ``vision_config``.
            quant_config: Optional quantization configuration forwarded to
                the vision model.
            is_mono: Must be false; monolith mode is not supported.
            prefix: Prefix forwarded to the vision model.

        Returns:
            An ``InternVisionModel`` whose layer count is overridden so that
            it ends at the layer selected by ``config.select_layer``.

        Raises:
            NotImplementedError: If ``is_mono`` is true.
        """
        if is_mono:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            # Negative indices count from the end of the encoder stack.
            num_hidden_layers = (
                config.vision_config.num_hidden_layers + vision_feature_layer + 1
            )
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )