# SPDX-License-Identifier: Apache-2.0

# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
10
11
from collections.abc import Mapping, Sequence
from typing import Optional

import torch
from PIL import Image
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
                       BaseInternVLProcessingInfo, BaseInternVLProcessor,
                       InternVLChatModel, InternVLDummyInputsBuilder,
                       InternVLMultiModalProcessor, build_transform,
                       find_closest_aspect_ratio, get_internvl_target_ratios)

# Module-level logger shared by everything in this file.
logger = init_logger(__name__)
36

37
38
39
40
41
42

def resolve_h2ovl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective minimum and maximum patch counts per image.

    Args:
        min_dynamic_patch: Configured lower bound on the number of patches.
        max_dynamic_patch: Configured upper bound on the number of patches.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the upper bound is not 1, the upper
            bound is raised by one to make room for the thumbnail patch.

    Returns:
        The ``(min, max)`` pair after applying the rules above.
    """
    if not dynamic_image_size:
        # A fixed-size image is always exactly one patch; the thumbnail
        # slot is only added when the upper bound differs from 1.
        return 1, 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch


def get_h2ovl_target_ratios(
    min_num: int,
    max_num: int,
    *,
    prior_aspect_ratio: Optional[tuple[int, int]],
) -> list[tuple[int, int]]:
    """Return the candidate tiling ratios, optionally filtered by a prior.

    Args:
        min_num: Lower bound passed to ``get_internvl_target_ratios``.
        max_num: Upper bound passed to ``get_internvl_target_ratios``.
        prior_aspect_ratio: Tiling ratio chosen by an earlier pass, or
            ``None`` to skip filtering.

    Returns:
        The ratios from ``get_internvl_target_ratios``. When a prior is
        given, only ratios are kept for which *both* components leave a
        non-zero remainder when dividing the matching component of the
        prior (``prior[k] % ratio[k] != 0`` for k = 0 and k = 1).
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    if prior_aspect_ratio is None:
        return target_ratios

    # if prior_aspect_ratio is provided, filter the target ratios
    return [
        ratio for ratio in target_ratios
        if prior_aspect_ratio[0] % ratio[0] != 0
        and prior_aspect_ratio[1] % ratio[1] != 0
    ]


# modified to include blocks generated in second pass
def calculate_h2ovl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int, tuple[int, int]]:
    """Pick a tiling ratio for an image and derive the resize target.

    Args:
        orig_width: Width of the source image.
        orig_height: Height of the source image; a value of 0 raises
            ``ZeroDivisionError`` when the aspect ratio is computed.
        target_ratios: Candidate ratios handed to
            ``find_closest_aspect_ratio``.
        image_size: Edge length of one square tile.
        use_thumbnail: Whether to count an extra thumbnail block when the
            tiling has more than one block.

    Returns:
        ``(blocks, target_width, target_height, target_aspect_ratio)``
        where ``blocks`` is the product of the two ratio components (plus
        one for the thumbnail when applicable) and the target dimensions
        are ``image_size`` times the respective ratio component.
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height, target_aspect_ratio


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio
def dynamic_preprocess_h2ovl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[list[Image.Image], tuple[int, int]]:
    """Resize an image to the chosen tiling and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate tiling ratios to choose from.
        image_size: Edge length of each square tile.
        use_thumbnail: When true and more than one tile was produced, a
            copy of the whole image resized to one tile is appended last.

    Returns:
        ``(tiles, target_aspect_ratio)``. Tiles are ordered row by row,
        left to right, followed by the optional thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    (
        blocks,
        target_width,
        target_height,
        target_aspect_ratio,
    ) = calculate_h2ovl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # Number of tiles in one row; constant for every tile, so compute once.
    tiles_per_row = target_width // image_size

    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        # split the image
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images, target_aspect_ratio


152
def _preprocess_image(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    prior_aspect_ratio: Optional[tuple[int, int]],
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Tile one image and stack the transformed tiles into a tensor.

    Args:
        image: Source image.
        input_size: Edge length of each tile; also passed to
            ``build_transform``.
        min_num: Lower bound used to build the candidate ratios.
        max_num: Upper bound used to build the candidate ratios.
        use_thumbnail: Forwarded to ``dynamic_preprocess_h2ovl``.
        prior_aspect_ratio: Ratio from an earlier pass used to filter the
            candidates, or ``None``.

    Returns:
        ``(pixel_values, target_aspect_ratio)`` where ``pixel_values`` is
        the per-tile transform outputs stacked along a new first dimension.
    """
    target_ratios = get_h2ovl_target_ratios(
        min_num,
        max_num,
        prior_aspect_ratio=prior_aspect_ratio,
    )

    transform = build_transform(input_size=input_size)
    tiles, target_aspect_ratio = dynamic_preprocess_h2ovl(
        image,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
        target_ratios=target_ratios,
    )

    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values, target_aspect_ratio


179
180
# refactored to use the _preprocess_image function
def image_to_pixel_values_h2ovl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_msac: bool,
) -> torch.Tensor:
    """Convert one image into stacked tile pixel values.

    Args:
        image: Source image.
        input_size: Edge length of each tile.
        min_num: Lower patch bound; only used when ``use_msac`` is false.
        max_num: Upper patch bound, used in both modes.
        use_thumbnail: Thumbnail switch; only used when ``use_msac`` is
            false (the MSAC passes always request a thumbnail).
        use_msac: When true, the image is processed twice and the results
            are merged.

    Returns:
        A tensor with one entry per tile along the first dimension.
    """
    if not use_msac:
        pixel_values, _ = _preprocess_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
            prior_aspect_ratio=None,
        )
        return pixel_values

    # when MSAC is turned on, we need to process the image twice
    # first pass
    pixel_values1, aspect_ratio1 = _preprocess_image(
        image,
        input_size=input_size,
        min_num=1,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=None,
    )
    # second pass: candidates are filtered against the first-pass ratio
    pixel_values2, _ = _preprocess_image(
        image,
        input_size=input_size,
        min_num=3,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=aspect_ratio1,
    )

    # combine pixel values: second-pass entries except the last, then
    # first-pass entries except the last, then the last second-pass entry.
    # The last entry of a pass is its thumbnail whenever that pass produced
    # more than one tile; a single-tile first pass has no thumbnail, so
    # its only entry is dropped by the ``[:-1]`` slice.
    return torch.cat(
        [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0)


226
class H2OVLProcessor(BaseInternVLProcessor):
    """InternVL-style processor with H2OVL tiling and optional MSAC.

    Besides ``use_msac`` (set here), the methods read ``tokenizer``,
    ``min_dynamic_patch``, ``max_dynamic_patch``, ``dynamic_image_size``,
    ``use_thumbnail``, ``image_size`` and ``num_image_token`` from the
    instance; these are presumably set by the base class -- TODO confirm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_msac: Optional[bool] = None,
    ) -> None:
        """Initialise the base processor and resolve the MSAC switch.

        Args:
            config: Model config; supplies ``use_msac`` when the argument
                is ``None``.
            tokenizer: Tokenizer forwarded to the base class.
            min_dynamic_patch: Optional override forwarded to the base.
            max_dynamic_patch: Optional override forwarded to the base.
            dynamic_image_size: Optional override forwarded to the base.
            use_msac: Explicit MSAC switch, or ``None`` to use the config.
        """
        super().__init__(
            config,
            tokenizer,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        if use_msac is None:
            use_msac = config.use_msac
        assert isinstance(use_msac, bool)

        self.use_msac = use_msac

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_CONTEXT`` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl_features(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """Return ``IMG_CONTEXT`` repeated ``feature_size`` times.

        ``num_patches`` is accepted but not used here.
        """
        return IMG_CONTEXT * feature_size

    def get_image_repl_full(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """Return the feature placeholder wrapped in ``IMG_START``/``IMG_END``."""
        features = self.get_image_repl_features(feature_size, num_patches)
        return IMG_START + features + IMG_END

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve the patch bounds, defaulting each argument to the instance.

        Any argument left as ``None`` is replaced by the attribute of the
        same name before delegating to ``resolve_h2ovl_min_max_num``.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_h2ovl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
        prior_aspect_ratio: Optional[tuple[int, int]] = None,
        override_min_num: Optional[int] = None,
    ) -> list[tuple[int, int]]:
        """Build the candidate tiling ratios for this processor.

        The bounds come from ``resolve_min_max_num``; ``override_min_num``,
        when given, replaces the resolved minimum. ``prior_aspect_ratio``
        is forwarded to ``get_h2ovl_target_ratios`` for filtering.
        """
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        if override_min_num is not None:
            min_num = override_min_num

        return get_h2ovl_target_ratios(
            min_num,
            max_num,
            prior_aspect_ratio=prior_aspect_ratio,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        use_msac: Optional[bool] = None,
    ) -> int:
        """Return the number of placeholder tokens for one image.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            use_msac: MSAC switch; ``None`` falls back to ``self.use_msac``.

        Returns:
            The patch count multiplied by ``self.num_image_token``.
        """
        use_msac = (self.use_msac if use_msac is None else use_msac)

        use_thumbnail = self.use_thumbnail

        if use_msac:
            # First pass: minimum forced to 1, no prior.
            target_ratios_1 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                override_min_num=1,
            )
            num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios_1,
                use_thumbnail=True,
            )

            # Second pass: minimum forced to 3, filtered by the first ratio.
            target_ratios_2 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                prior_aspect_ratio=aspect_ratio_1,
                override_min_num=3,
            )
            num_patches_2, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios_2,
                use_thumbnail=True,
            )

            # One entry of the first pass is dropped when the two passes
            # are merged, hence the -1.
            num_patches = num_patches_1 + num_patches_2 - 1
        else:
            target_ratios = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
            )
            num_patches, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=self.image_size,
                target_ratios=target_ratios,
                use_thumbnail=use_thumbnail,
            )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image to its stacked tile tensor.

        MSAC is only applied when exactly one image is passed; with any
        other count it is switched off regardless of ``self.use_msac``.
        """
        use_msac = self.use_msac if len(images) == 1 else False

        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_h2ovl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
                use_msac=use_msac,
            ) for image in images
        ]


class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds ``H2OVLProcessor`` and sizes its inputs."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> H2OVLProcessor:
        """Create an ``H2OVLProcessor`` via ``self.ctx.init_processor``.

        Overrides that are ``None`` are not forwarded, so the processor
        falls back to its own defaults for them.
        """
        if min_dynamic_patch is not None:
            kwargs["min_dynamic_patch"] = min_dynamic_patch
        if max_dynamic_patch is not None:
            kwargs["max_dynamic_patch"] = max_dynamic_patch
        if dynamic_image_size is not None:
            kwargs["dynamic_image_size"] = dynamic_image_size

        return self.ctx.init_processor(
            H2OVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of tokens a single image can occupy.

        With at most one image the processor's own MSAC setting applies
        (``use_msac=None``); with more images MSAC is forced off. Only the
        value that is actually returned is computed.
        """
        use_msac = None if mm_counts.get("image", 0) <= 1 else False

        return {"image": self.get_max_image_tokens(use_msac=use_msac)}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[H2OVLProcessor],
        use_msac: Optional[bool] = None,
    ) -> int:
        """Return the token count for one image of the given size.

        A processor is created with ``get_hf_processor()`` when
        ``processor`` is ``None``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
            use_msac=use_msac,
        )

    def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
        """Return the token count for the size that yields the most features.

        The size comes from ``get_image_size_with_most_features``.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
            use_msac=use_msac,
        )
463
464


465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
                               ):
    """Multi-modal processor for H2OVL built on the InternVL processor."""

    def __init__(self,
                 info: H2OVLProcessingInfo,
                 dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        """Initialise the base processor and drop the cache if unsafe.

        The cache is discarded when more than one image per prompt may be
        passed, because the processor output then depends on the image
        count.
        """
        super().__init__(
            info,
            dummy_inputs,
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )

        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
        # The mapping may have no "image" entry; a plain ``[...]`` lookup
        # would raise KeyError. An absent limit is treated like a
        # multi-image limit so that the cache is never kept in a
        # configuration where several images could reach the processor.
        # NOTE(review): confirm the framework default for an unset limit.
        image_limit = mm_limit.get("image")
        if self.cache is not None and (image_limit is None
                                       or image_limit >= 2):
            # The processor output depends on the number of images passed,
            # making it incompatible with processing cache which is supposed
            # to be invariant of how many images are passed per prompt
            self.cache = None
            logger.warning_once(
                f"{type(self).__name__} does not support processing cache with "
                "multi-image support enabled.")

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Build the replacement rule that expands each ``<image>`` marker.

        Returns:
            A single ``PromptReplacement`` whose replacement is computed
            per image item.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        num_images = len(image_num_patches)

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    # MSAC follows the processor setting for a single
                    # image and is forced off otherwise.
                    use_msac=None if num_images == 1 else False,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return PromptUpdateDetails(
                full=hf_processor.get_image_repl_full(feature_size,
                                                      num_patches),
                features=hf_processor.get_image_repl_features(
                    feature_size, num_patches),
            )

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
545
546


547
548
549
550
@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,
    info=H2OVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
    """H2OVL chat model; differs from the base only in vision-model setup."""

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the ``InternVisionModel`` used as the vision tower.

        Args:
            config: Model config providing ``select_layer`` and
                ``vision_config``.
            quant_config: Optional quantization config passed through.
            is_mono: Must be false; monolith mode is not supported.
            prefix: Prefix passed through to the vision model.

        Raises:
            NotImplementedError: If ``is_mono`` is true.
        """
        if is_mono:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)

        # Only build as many layers as are needed to reach the selected
        # feature layer; a negative index counts back from the last layer.
        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )