# SPDX-License-Identifier: Apache-2.0

# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
from typing import Optional, Union

import torch
from PIL import Image
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
                       BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel, InternVLDummyInputsBuilder,
                       InternVLMultiModalProcessor, build_transform,
                       find_closest_aspect_ratio, get_internvl_target_ratios)


def resolve_h2ovl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective ``(min, max)`` patch-count bounds.

    Args:
        min_dynamic_patch: Lower bound used when dynamic sizing is enabled.
        max_dynamic_patch: Upper bound used when dynamic sizing is enabled.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the resolved maximum is not 1, the
            maximum is raised by one.

    Returns:
        The ``(min, max)`` pair after applying the rules above.
    """
    if not dynamic_image_size:
        # Both bounds are 1 here, so the thumbnail bump (which requires
        # max != 1) can never apply.
        return 1, 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch


def get_h2ovl_target_ratios(
    min_num: int,
    max_num: int,
    *,
    prior_aspect_ratio: Optional[tuple[int, int]],
) -> list[tuple[int, int]]:
    """Return candidate tile-grid ratios, optionally filtered by a prior.

    The candidates come from ``get_internvl_target_ratios(min_num, max_num)``.
    When ``prior_aspect_ratio`` is given, a candidate is kept only if its
    first component does not divide the prior's first component AND its
    second component does not divide the prior's second component.

    Args:
        min_num: Forwarded to ``get_internvl_target_ratios``.
        max_num: Forwarded to ``get_internvl_target_ratios``.
        prior_aspect_ratio: Ratio chosen in an earlier pass, or ``None`` to
            skip filtering.

    Returns:
        The (possibly filtered) list of candidate ratios.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    if prior_aspect_ratio is None:
        return target_ratios

    # Keep only ratios where neither component evenly divides the matching
    # component of the prior ratio.
    return [
        ratio for ratio in target_ratios
        if prior_aspect_ratio[0] % ratio[0] != 0
        and prior_aspect_ratio[1] % ratio[1] != 0
    ]


# modified to include blocks generated in second pass
def calculate_h2ovl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int, tuple[int, int]]:
    """Pick a tile grid for an image and derive the resize target.

    Args:
        orig_width: Width of the source image.
        orig_height: Height of the source image.
        target_ratios: Candidate grid ratios passed to
            ``find_closest_aspect_ratio``.
        image_size: Side length of one tile.
        use_thumbnail: When true and the grid has more than one tile, one
            extra block is counted.

    Returns:
        ``(blocks, target_width, target_height, target_aspect_ratio)`` where
        the target dimensions are ``image_size`` times the respective ratio
        component.
    """
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # Component 0 scales the width, component 1 scales the height.
    cols = target_aspect_ratio[0]
    rows = target_aspect_ratio[1]
    blocks = cols * rows

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, image_size * cols, image_size * rows, target_aspect_ratio


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio
def dynamic_preprocess_h2ovl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[list[Image.Image], tuple[int, int]]:
    """Resize an image to a tile grid and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate grid ratios forwarded to
            ``calculate_h2ovl_targets``.
        image_size: Side length of each square tile.
        use_thumbnail: When true and more than one tile was produced, a copy
            of the whole image resized to ``image_size`` x ``image_size`` is
            appended as the last element.

    Returns:
        ``(tiles, target_aspect_ratio)``: the tiles in row-major order
        (plus the optional thumbnail) and the grid ratio that was selected.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    # (the thumbnail, if any, is appended explicitly at the end)
    blocks, target_width, target_height, target_aspect_ratio = (
        calculate_h2ovl_targets(
            orig_width=orig_width,
            orig_height=orig_height,
            target_ratios=target_ratios,
            image_size=image_size,
            use_thumbnail=False,
        ))

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # Loop-invariant: number of tiles that fit across one row.
    tiles_per_row = target_width // image_size

    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        upper = row * image_size
        # split the image
        processed_images.append(
            resized_img.crop(
                (left, upper, left + image_size, upper + image_size)))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))

    return processed_images, target_aspect_ratio


148
def _preprocess_image(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    prior_aspect_ratio: Optional[tuple[int, int]],
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Tile one image and convert the tiles to a stacked tensor.

    Args:
        image: Source image.
        input_size: Tile side length, also passed to ``build_transform``.
        min_num: Lower bound forwarded to ``get_h2ovl_target_ratios``.
        max_num: Upper bound forwarded to ``get_h2ovl_target_ratios``.
        use_thumbnail: Forwarded to ``dynamic_preprocess_h2ovl``.
        prior_aspect_ratio: Ratio from an earlier pass used to filter the
            candidate ratios, or ``None`` for no filtering.

    Returns:
        ``(pixel_values, target_aspect_ratio)``: the transformed tiles
        stacked along a new leading dimension, and the selected grid ratio.
    """
    candidate_ratios = get_h2ovl_target_ratios(
        min_num,
        max_num,
        prior_aspect_ratio=prior_aspect_ratio,
    )

    transform = build_transform(input_size=input_size)
    tiles, target_aspect_ratio = dynamic_preprocess_h2ovl(
        image,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
        target_ratios=candidate_ratios,
    )

    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values, target_aspect_ratio


175
176
# refactored to use the _preprocess_image function
def image_to_pixel_values_h2ovl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_msac: bool,
) -> torch.Tensor:
    """Convert one image into a tensor of tile pixel values.

    Without MSAC the image is tiled once using the given arguments. With
    MSAC it is tiled twice and the results are merged; on that path
    ``min_num`` and ``use_thumbnail`` are ignored (the passes use fixed
    minimums of 1 and 3 and always request a thumbnail).

    Args:
        image: Source image.
        input_size: Tile side length.
        min_num: Minimum tile count (non-MSAC path only).
        max_num: Maximum tile count (both paths).
        use_thumbnail: Whether to append a thumbnail (non-MSAC path only).
        use_msac: Whether to run the two-pass MSAC procedure.

    Returns:
        The tile tensors concatenated along dimension 0.
    """
    if not use_msac:
        pixel_values, _ = _preprocess_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
            prior_aspect_ratio=None,
        )
        return pixel_values

    # when MSAC is turned on, we need to process the image twice
    # first pass
    pixel_values1, aspect_ratio1 = _preprocess_image(
        image,
        input_size=input_size,
        min_num=1,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=None,
    )
    # second pass, with candidate ratios filtered by the first-pass ratio
    pixel_values2, _ = _preprocess_image(
        image,
        input_size=input_size,
        min_num=3,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=aspect_ratio1,
    )

    # combine pixel values: second-pass entries without their last one,
    # then first-pass entries without their last one, then the last
    # second-pass entry. Each `[:-1]` drops the final element regardless of
    # whether it is a thumbnail, so a single-tile first pass contributes
    # nothing here.
    return torch.cat(
        [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0)


222
class H2OVLProcessor(BaseInternVLProcessor):
    """InternVL-style processor for H2OVL with optional MSAC tiling.

    Attributes such as ``use_thumbnail``, ``image_size``, ``num_image_token``
    and the ``*_dynamic_patch`` defaults are read from ``self`` but not set
    here; presumably the base class initializes them (verify against
    ``BaseInternVLProcessor``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_msac: Optional[bool] = None,
    ) -> None:
        """Initialize the processor.

        Args:
            config: Model config; ``config.use_msac`` supplies the default
                for ``use_msac``.
            tokenizer: Tokenizer forwarded to the base class.
            min_dynamic_patch: Forwarded to the base class.
            max_dynamic_patch: Forwarded to the base class.
            dynamic_image_size: Forwarded to the base class.
            use_msac: Explicit MSAC switch; ``None`` falls back to the
                config value.
        """
        super().__init__(
            config,
            tokenizer,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        if use_msac is None:
            use_msac = config.use_msac
        assert isinstance(use_msac, bool)

        self.use_msac = use_msac

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_CONTEXT`` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the replacement text for one image placeholder.

        The text is ``IMG_START``, then ``IMG_CONTEXT`` repeated
        ``feature_size`` times, then ``IMG_END``. ``num_patches`` is accepted
        but not used by this implementation.
        """
        context_tokens = IMG_CONTEXT * feature_size
        full_text = IMG_START + context_tokens + IMG_END

        return PromptUpdateDetails.select_text(full_text, IMG_CONTEXT)

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve ``(min, max)`` patch bounds.

        Each argument left as ``None`` falls back to the attribute of the
        same name on ``self``; the result is computed by
        ``resolve_h2ovl_min_max_num``.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_h2ovl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
        prior_aspect_ratio: Optional[tuple[int, int]] = None,
        override_min_num: Optional[int] = None,
    ) -> list[tuple[int, int]]:
        """Return candidate grid ratios for the resolved patch bounds.

        Args:
            min_dynamic_patch: See ``resolve_min_max_num``.
            max_dynamic_patch: See ``resolve_min_max_num``.
            dynamic_image_size: See ``resolve_min_max_num``.
            use_thumbnail: See ``resolve_min_max_num``.
            prior_aspect_ratio: Optional ratio used to filter candidates.
            override_min_num: When given, replaces the resolved minimum.
        """
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        if override_min_num is not None:
            min_num = override_min_num

        return get_h2ovl_target_ratios(
            min_num,
            max_num,
            prior_aspect_ratio=prior_aspect_ratio,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        use_msac: Optional[bool] = None,
    ) -> int:
        """Compute how many image tokens an image of this size produces.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            use_msac: Overrides ``self.use_msac`` when not ``None``.

        Returns:
            The patch count multiplied by ``self.num_image_token``.
        """
        if use_msac is None:
            use_msac = self.use_msac

        use_thumbnail = self.use_thumbnail

        def count_patches(
            target_ratios: list[tuple[int, int]],
            *,
            with_thumbnail: bool,
        ) -> tuple[int, tuple[int, int]]:
            # Shared wrapper so every pass uses the same image geometry.
            blocks, _, _, aspect_ratio = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios,
                use_thumbnail=with_thumbnail,
            )
            return blocks, aspect_ratio

        if use_msac:
            target_ratios_1 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                override_min_num=1,
            )
            num_patches_1, aspect_ratio_1 = count_patches(
                target_ratios_1, with_thumbnail=True)

            target_ratios_2 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                prior_aspect_ratio=aspect_ratio_1,
                override_min_num=3,
            )
            num_patches_2, _ = count_patches(
                target_ratios_2, with_thumbnail=True)

            # One entry is subtracted from the combined count; this matches
            # image_to_pixel_values_h2ovl, which drops the last element of
            # the first pass when merging.
            num_patches = num_patches_1 + num_patches_2 - 1
        else:
            target_ratios = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
            )
            num_patches, _ = count_patches(
                target_ratios, with_thumbnail=use_thumbnail)

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image to its tile tensor.

        MSAC is applied only when exactly one image is given; with any other
        count it is disabled regardless of ``self.use_msac``.
        """
        use_msac = self.use_msac if len(images) == 1 else False

        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_h2ovl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
                use_msac=use_msac,
            ) for image in images
        ]


class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds and queries an ``H2OVLProcessor``."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> H2OVLProcessor:
        """Create an ``H2OVLProcessor`` via ``self.ctx.init_processor``.

        The three named overrides are forwarded only when they are not
        ``None``; all other keyword arguments are passed through unchanged.
        """
        overrides = {
            "min_dynamic_patch": min_dynamic_patch,
            "max_dynamic_patch": max_dynamic_patch,
            "dynamic_image_size": dynamic_image_size,
        }
        for key, value in overrides.items():
            if value is not None:
                kwargs[key] = value

        return self.ctx.init_processor(
            H2OVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[H2OVLProcessor],
        use_msac: Optional[bool] = None,
    ) -> int:
        """Return the image-token count for an image of the given size.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            processor: Processor to query; when ``None`` a new one is
                obtained from ``get_hf_processor()``.
            use_msac: Forwarded to the processor (``None`` lets the
                processor use its own setting).
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
            use_msac=use_msac,
        )
431
432


433
434
435
class H2OVLMultiModalProcessor(
        InternVLMultiModalProcessor[H2OVLProcessingInfo]):
    """Multi-modal processor for H2OVL prompts."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Build the replacement rule for ``<image>`` placeholders.

        Returns a single ``PromptReplacement`` whose replacement text is
        computed per item from that item's feature size.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        num_images = len(image_num_patches)

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                # With exactly one image the processor's own MSAC setting
                # applies (None); otherwise MSAC is forced off.
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    use_msac=None if num_images == 1 else False,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """Apply the HF processor, bypassing the cache for multi-image input.

        Prompts with more than one image go straight to
        ``_apply_hf_processor_main``; all others use the parent
        implementation.
        """
        # The processor logic is different for len(images) <= 1 vs > 1
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 1:
            # This code path corresponds to the cache being disabled
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

511

512
513
514
515
@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,
    info=H2OVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
    """H2OVL chat model; differs from the parent in vision-model setup."""

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the ``InternVisionModel`` vision tower.

        Args:
            config: Model config; ``select_layer`` and
                ``vision_config.num_hidden_layers`` determine how many
                hidden layers are requested.
            quant_config: Forwarded to ``InternVisionModel``.
            is_mono: Must be false; the mono variant is not supported.
            prefix: Forwarded to ``InternVisionModel``.

        Raises:
            NotImplementedError: If ``is_mono`` is true.
        """
        if is_mono:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            # Negative index counts from the end: -1 selects all layers.
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )