h2ovl.py 11.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
import torch
from PIL import Image

14
from vllm.tokenizers.hf import HfTokenizer
15
16

from .internvl import (
17
18
    InternVLImageProcessor,
    InternVLProcessor,
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
    build_transform,
    find_closest_aspect_ratio,
    get_internvl_target_ratios,
)


def resolve_h2ovl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the ``(min_num, max_num)`` tile-count bounds for H2OVL.

    With ``dynamic_image_size`` disabled, both bounds collapse to 1.
    Otherwise the configured patch bounds are used. When a thumbnail is
    requested and the upper bound is not 1, the upper bound is raised by
    one to account for the extra thumbnail tile.
    """
    if dynamic_image_size:
        lower, upper = min_dynamic_patch, max_dynamic_patch
    else:
        lower, upper = 1, 1

    # A thumbnail tile is only added when more than one tile is possible.
    if use_thumbnail and upper != 1:
        upper = upper + 1

    return lower, upper


def get_h2ovl_target_ratios(
    min_num: int,
    max_num: int,
    *,
    prior_aspect_ratio: tuple[int, int] | None,
) -> list[tuple[int, int]]:
    """Return the candidate ``(cols, rows)`` tile grids for H2OVL.

    Candidates come from ``get_internvl_target_ratios(min_num, max_num)``.
    When ``prior_aspect_ratio`` is given, a candidate is kept only if its
    first component does not evenly divide ``prior_aspect_ratio[0]`` AND
    its second component does not evenly divide ``prior_aspect_ratio[1]``.
    """
    candidates = get_internvl_target_ratios(min_num, max_num)
    if prior_aspect_ratio is None:
        return candidates

    # Reject any grid that shares a divisor relationship with the prior grid
    # along either axis (presumably so the MSAC second pass yields a layout
    # that does not merely subdivide the first-pass grid).
    return [
        ratio
        for ratio in candidates
        if not (
            prior_aspect_ratio[0] % ratio[0] == 0
            or prior_aspect_ratio[1] % ratio[1] == 0
        )
    ]


# modified to include blocks generated in second pass
def calculate_h2ovl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int, tuple[int, int]]:
    """Choose a tile grid for an image and describe the resulting layout.

    Returns ``(blocks, target_width, target_height, target_aspect_ratio)``:

    - ``target_aspect_ratio`` is the grid returned by
      ``find_closest_aspect_ratio`` for this image and ``target_ratios``;
    - ``target_width`` / ``target_height`` are the grid's first / second
      component multiplied by ``image_size``;
    - ``blocks`` is the product of the grid components, plus one when
      ``use_thumbnail`` is set and that product is not 1.
    """
    chosen = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    canvas_width = chosen[0] * image_size
    canvas_height = chosen[1] * image_size
    tile_count = chosen[0] * chosen[1]

    # Reserve one extra slot for the thumbnail tile when the image is split.
    if use_thumbnail and tile_count != 1:
        tile_count += 1

    return tile_count, canvas_width, canvas_height, chosen


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio
def dynamic_preprocess_h2ovl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[list[Image.Image], tuple[int, int]]:
    """Cut ``image`` into ``image_size`` x ``image_size`` tiles.

    The grid is picked by ``calculate_h2ovl_targets`` from ``target_ratios``.
    The image is resized to that grid's canvas size and cropped tile by tile
    in row-major order. When ``use_thumbnail`` is set and more than one tile
    was produced, a resize of the whole original image to
    ``(image_size, image_size)`` is appended last.

    Returns the list of tiles and the chosen grid.
    """
    src_width, src_height = image.size

    # Ask for the raw tile count; the thumbnail is handled separately below.
    tile_count, canvas_width, canvas_height, grid = calculate_h2ovl_targets(
        orig_width=src_width,
        orig_height=src_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    canvas = image.resize((canvas_width, canvas_height))
    tiles_per_row = canvas_width // image_size

    tiles = []
    for tile_idx in range(tile_count):
        row, col = divmod(tile_idx, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(
            canvas.crop((left, top, left + image_size, top + image_size))
        )

    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles, grid


def _preprocess_image(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    prior_aspect_ratio: tuple[int, int] | None,
) -> tuple[torch.Tensor, tuple[int, int]]:
    """Tile ``image`` and stack the transformed tiles into one tensor.

    Candidate grids come from ``get_h2ovl_target_ratios`` (optionally
    filtered by ``prior_aspect_ratio``). Each tile is passed through
    ``build_transform(input_size=input_size)``, and the results are
    stacked along a new leading dimension.

    Returns the stacked tensor and the grid that was chosen.
    """
    candidate_grids = get_h2ovl_target_ratios(
        min_num, max_num, prior_aspect_ratio=prior_aspect_ratio
    )

    to_tensor = build_transform(input_size=input_size)
    tiles, chosen_grid = dynamic_preprocess_h2ovl(
        image,
        target_ratios=candidate_grids,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    stacked = torch.stack(list(map(to_tensor, tiles)))
    return stacked, chosen_grid


# refactored to use the _preprocess_image function
def image_to_pixel_values_h2ovl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_msac: bool,
) -> torch.Tensor:
    """Convert one image into a tensor of stacked tiles.

    Without MSAC this is a single ``_preprocess_image`` pass using the given
    ``min_num`` / ``max_num`` / ``use_thumbnail``.

    With MSAC the image is tiled twice, both times with a thumbnail
    requested:

    - first over grids from 1 to ``max_num`` tiles;
    - then over grids from 3 to ``max_num`` tiles, filtered by the first
      pass's grid.

    The result concatenates the second pass without its last entry, the
    first pass without its last entry, and finally the second pass's last
    entry. The last entry is the thumbnail whenever one was appended.
    Note that ``min_num`` and ``use_thumbnail`` are ignored in this mode.
    """
    if not use_msac:
        single_pass, _ = _preprocess_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
            prior_aspect_ratio=None,
        )
        return single_pass

    # MSAC pass 1: unconstrained grid search starting from one tile.
    coarse, coarse_grid = _preprocess_image(
        image,
        input_size=input_size,
        min_num=1,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=None,
    )
    # MSAC pass 2: at least three tiles, restricted relative to pass 1's grid.
    fine, _ = _preprocess_image(
        image,
        input_size=input_size,
        min_num=3,
        max_num=max_num,
        use_thumbnail=True,
        prior_aspect_ratio=coarse_grid,
    )

    return torch.cat([fine[:-1], coarse[:-1], fine[-1:]], 0)


216
class H2OVLImageProcessor(InternVLImageProcessor):
    """InternVL image processor using H2OVL tiling rules.

    Adds an optional MSAC two-pass tiling mode (``use_msac``). Tile-count
    bounds are resolved with ``resolve_h2ovl_min_max_num``, and pixel values
    are produced with ``image_to_pixel_values_h2ovl``.
    """

    def __init__(
        self,
        image_size: int,
        min_dynamic_patch: int,
        max_dynamic_patch: int,
        dynamic_image_size: bool,
        use_thumbnail: bool,
        use_msac: bool,
    ) -> None:
        super().__init__(
            image_size=image_size,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        # Enables the two-pass MSAC tiling; `_images_to_pixel_values_lst`
        # only honours it when exactly one image is processed.
        self.use_msac = use_msac

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> tuple[int, int]:
        """Resolve the ``(min_num, max_num)`` tile bounds.

        Any argument left as ``None`` falls back to the attribute of the same
        name on this instance. The values are then passed to
        ``resolve_h2ovl_min_max_num``.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_h2ovl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Convert each image into a tensor of stacked tiles.

        MSAC is used only when ``self.use_msac`` is set and exactly one image
        is given. Returns one tensor per input image, in order.
        """
        # MSAC applies to single-image inputs only.
        use_msac = self.use_msac if len(images) == 1 else False

        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            # The thumbnail is added inside image_to_pixel_values_h2ovl, so it
            # must not also widen max_num here.
            use_thumbnail=False,
        )

        return [
            image_to_pixel_values_h2ovl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
                use_msac=use_msac,
            )
            for image in images
        ]


class H2OVLProcessor(InternVLProcessor):
    """InternVL processor whose image-token count follows H2OVL tiling.

    The number of image tokens is the tile count from
    ``calculate_h2ovl_targets`` (two passes when MSAC is enabled) multiplied
    by ``image_seq_length``.
    """

    def __init__(
        self,
        image_processor: H2OVLImageProcessor,
        tokenizer: HfTokenizer,
        *,
        image_seq_length: int,
        start_image_token: str = "<img>",
        end_image_token: str = "</img>",
        ctx_image_token: str = "<IMG_CONTEXT>",
    ) -> None:
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            image_seq_length=image_seq_length,
            start_image_token=start_image_token,
            end_image_token=end_image_token,
            ctx_image_token=ctx_image_token,
        )

        # Annotation only: narrows the declared type of `image_processor`
        # for type checkers; no value is assigned here.
        self.image_processor: H2OVLImageProcessor

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
        prior_aspect_ratio: tuple[int, int] | None = None,
        override_min_num: int | None = None,
    ) -> list[tuple[int, int]]:
        """Return candidate tile grids for token counting.

        Bounds come from ``image_processor.resolve_min_max_num``. When
        ``override_min_num`` is given, it replaces the resolved lower bound
        (the MSAC passes in ``get_num_image_tokens`` use 1 and 3).
        ``prior_aspect_ratio`` is forwarded to ``get_h2ovl_target_ratios``.
        """
        min_num, max_num = self.image_processor.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        if override_min_num is not None:
            min_num = override_min_num

        return get_h2ovl_target_ratios(
            min_num,
            max_num,
            prior_aspect_ratio=prior_aspect_ratio,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        use_msac: bool | None = None,
    ) -> int:
        """Return the number of image tokens for an image of this size.

        Args:
            image_width: Original image width.
            image_height: Original image height.
            use_msac: Whether to count the two-pass MSAC layout; defaults to
                ``image_processor.use_msac`` when ``None``.

        Returns:
            Tile count multiplied by ``self.image_seq_length``.
        """
        image_processor = self.image_processor
        use_msac = image_processor.use_msac if use_msac is None else use_msac

        use_thumbnail = image_processor.use_thumbnail

        if use_msac:
            # Pass 1: grids starting from one tile, thumbnail counted.
            target_ratios_1 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                override_min_num=1,
            )
            num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=image_processor.image_size,
                target_ratios=target_ratios_1,
                use_thumbnail=True,
            )

            # Pass 2: at least three tiles, filtered by pass 1's grid.
            target_ratios_2 = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
                prior_aspect_ratio=aspect_ratio_1,
                override_min_num=3,
            )
            num_patches_2, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=image_processor.image_size,
                target_ratios=target_ratios_2,
                use_thumbnail=True,
            )

            # image_to_pixel_values_h2ovl drops the last entry of pass 1 when
            # concatenating, hence the "- 1".
            num_patches = num_patches_1 + num_patches_2 - 1
        else:
            target_ratios = self.resolve_target_ratios(
                use_thumbnail=False,  # Applied in calculate_targets
            )
            num_patches, _, _, _ = calculate_h2ovl_targets(
                orig_width=image_width,
                orig_height=image_height,
                image_size=image_processor.image_size,
                target_ratios=target_ratios,
                use_thumbnail=use_thumbnail,
            )

        return num_patches * self.image_seq_length