# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
12
13
14
15
16

import torch
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
17
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.multimodal.inputs import MultiModalKwargsItems
19
20
21
22
23
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
24
from vllm.multimodal.processing.processor import (
25
    MultiModalProcessingInfo,
26
    ProcessorInputs,
27
28
    PromptReplacement,
    PromptUpdate,
29
    TimingContext,
30
)
31
from vllm.transformers_utils.processors.h2ovl import H2OVLImageProcessor, H2OVLProcessor
32
33

from .intern_vit import InternVisionModel
34
35
36
37
38
39
from .internvl import (
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    InternVLChatModel,
)
40

41
42

class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds the H2OVL image processor and HF processor."""

    def get_image_processor(self, **kwargs):
        """Build an ``H2OVLImageProcessor`` from merged multimodal kwargs.

        ``kwargs`` is first passed through ``self.ctx.get_merged_mm_kwargs``.
        Any of the options below that are still missing afterwards fall back
        to the values stored on the HF config (``image_size`` comes from the
        vision sub-config, the rest from the top-level config).
        """
        config = self.get_hf_config()
        vision_config = config.vision_config

        kwargs = self.ctx.get_merged_mm_kwargs(kwargs)
        kwargs.setdefault("image_size", vision_config.image_size)
        kwargs.setdefault("min_dynamic_patch", config.min_dynamic_patch)
        kwargs.setdefault("max_dynamic_patch", config.max_dynamic_patch)
        kwargs.setdefault("dynamic_image_size", config.dynamic_image_size)
        kwargs.setdefault("use_thumbnail", config.use_thumbnail)
        kwargs.setdefault("use_msac", config.use_msac)

        return H2OVLImageProcessor(**kwargs)

    def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor:
        """Build an ``H2OVLProcessor`` around the tokenizer and image processor.

        ``kwargs`` is forwarded unchanged to ``get_image_processor``. The
        ``image_seq_length`` handed to the processor is derived from the image
        processor's ``image_size``, the vision config's ``patch_size`` and the
        config's ``downsample_ratio``.
        """
        config = self.get_hf_config()
        vision_config = config.vision_config

        image_processor = self.get_image_processor(**kwargs)
        image_size = image_processor.image_size
        patch_size = vision_config.patch_size
        downsample_ratio = config.downsample_ratio
        # (patches per side)^2, scaled by the squared downsample ratio and
        # truncated to an int.
        image_seq_length = int((image_size // patch_size) ** 2 * (downsample_ratio**2))

        return H2OVLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            image_seq_length=image_seq_length,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: H2OVLProcessor,
        use_msac: bool | None = None,
    ) -> int:
        """Return the image token count reported by ``processor``.

        This is a thin delegation to ``processor.get_num_image_tokens``;
        ``use_msac`` is passed through as given, including ``None``.
        """
        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
            use_msac=use_msac,
        )
86
87


88
class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
    """Multimodal processor that expands ``<image>`` targets for H2OVL."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt replacement applied to each ``<image>`` target.

        The per-image patch counts are read from the processed output
        (``image_num_patches``) when available. For embedding inputs
        (``image_embeds``) the patch count is unknown and ``None`` is used
        for every item.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        out_mm_data = out_mm_kwargs.get_data()
        if "image_num_patches" in out_mm_data:
            image_num_patches = out_mm_data["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_data["image_embeds"])
        else:
            image_num_patches = []

        num_images = len(image_num_patches)

        def get_replacement_internvl(item_idx: int):
            """Build the replacement for the image at ``item_idx``."""
            # Looked up lazily so that this only runs when an image
            # replacement is actually requested.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    # With exactly one image the processor decides (None);
                    # otherwise MSAC is explicitly turned off.
                    use_msac=None if num_images == 1 else False,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(num_patches, num_features=feature_size)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """Apply the HF processor, bypassing the cache for multi-image prompts."""
        # The processor logic is different for len(images) <= 1 vs > 1
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if inputs.mm_data_items.get_count("image", strict=False) > 1:
            return self._apply_hf_processor(inputs, timing_ctx)

        return super()._cached_apply_hf_processor(inputs, timing_ctx)
154

155

156
157
158
@MULTIMODAL_REGISTRY.register_processor(
    H2OVLMultiModalProcessor,
    info=H2OVLProcessingInfo,
    dummy_inputs=BaseInternVLDummyInputsBuilder,
)
class H2OVLChatModel(InternVLChatModel):
    """InternVL-based chat model variant for H2OVL (non-monolithic only)."""

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the ``InternVisionModel`` truncated at ``config.select_layer``.

        Raises:
            NotImplementedError: If ``is_mono`` is true; monolith mode is not
                supported for H2OVL.
        """
        if is_mono:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            # A negative layer index counts back from the last hidden layer.
            num_hidden_layers = (
                config.vision_config.num_hidden_layers + vision_feature_layer + 1
            )
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Convert an image token count into an encoder token count.

        Returns 0 when either ``num_image_tokens`` or ``self.num_image_token``
        is not positive. Otherwise the number of whole patches is
        ``num_image_tokens // self.num_image_token`` and each patch
        contributes ``self.patch_tokens + 1`` encoder tokens.
        """
        if num_image_tokens <= 0 or self.num_image_token <= 0:
            return 0

        num_patches = num_image_tokens // self.num_image_token
        # NOTE(review): the extra +1 per patch is presumably the class token
        # of the vision encoder; confirm against InternVisionModel.
        return num_patches * (self.patch_tokens + 1)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Convert an encoder token count into a connector token count.

        Returns 0 when either ``num_vision_tokens`` or ``self.num_image_token``
        is not positive. Otherwise the number of whole patches is
        ``num_vision_tokens // (self.patch_tokens + 1)`` and each patch
        contributes ``self.num_image_token`` tokens.
        """
        if num_vision_tokens <= 0 or self.num_image_token <= 0:
            return 0

        num_patches = num_vision_tokens // (self.patch_tokens + 1)
        return num_patches * self.num_image_token