# internvl.py
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
7
import re
8
9
10
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
11
12
13
14
15
16
17
18
19
20
21

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
23
24
25
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.base import MultiModalInputs
28
from vllm.multimodal.utils import cached_get_tokenizer
29
from vllm.sequence import IntermediateTensors
30
from vllm.utils import is_list_of
31
32
33

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
34
from .interfaces import SupportsMultiModal, SupportsPP
35
36
from .utils import (flatten_bn, group_weights_with_prefix,
                    init_vllm_registered_model, merge_multimodal_embeddings)
37
38
39
40
41
42
43
44
45
46
47

# Markers that wrap the run of image placeholder tokens in a prompt.
IMG_START = '<img>'
IMG_END = '</img>'
# Emitted once per image feature token; these positions receive the vision
# embeddings in InternVLChatModel.forward.
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel ImageNet statistics used to normalize image tiles.
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input given as preprocessed, normalized pixel tiles."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`

    The number of tiles per image depends on its aspect ratio; the extra
    thumbnail tile is only present when `use_thumbnail` is enabled and the
    image was split into more than one tile.
    """


class InternVLImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either accepted image-input form; the "type" key tells them apart
# (see InternVLChatModel._process_image_input).
InternVLImageInputs = Union[
    InternVLImagePixelInputs,
    InternVLImageEmbeddingInputs,
]


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    """Return the torchvision pipeline applied to every image tile.

    The pipeline converts non-RGB images to RGB, bicubic-resizes to an
    ``input_size`` x ``input_size`` square, converts to a tensor and
    normalizes with the ImageNet mean/std.
    """

    def _ensure_rgb(img):
        # Leave RGB images untouched; convert everything else.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Return the ``(cols, rows)`` entry of *target_ratios* closest to
    *aspect_ratio*.

    Candidates are scanned in order and a strictly smaller aspect-ratio
    error always wins. On an exact tie, the later candidate wins only if the
    source area ``width * height`` exceeds half the pixel area of that
    candidate's grid of ``image_size`` tiles. ``(1, 1)`` is returned when
    *target_ratios* is empty.
    """
    source_area = width * height
    closest = (1, 1)
    smallest_diff = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        diff = abs(aspect_ratio - cols / rows)
        if diff < smallest_diff:
            smallest_diff, closest = diff, candidate
        elif diff == smallest_diff and (
                source_area > 0.5 * image_size * image_size * cols * rows):
            closest = candidate
    return closest


def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int, image_size: int,
                         use_thumbnail: bool) -> Tuple[int, int, int]:
    """Choose the tile grid for an image of the given size.

    Every ``(cols, rows)`` grid with ``min_num <= cols * rows <= max_num`` is a
    candidate; ``find_closest_aspect_ratio`` picks the one whose aspect ratio
    best matches the image.

    Args:
        orig_width: Source image width in pixels.
        orig_height: Source image height in pixels (must be non-zero).
        min_num: Minimum number of tiles.
        max_num: Maximum number of tiles.
        image_size: Side length of one square tile.
        use_thumbnail: Count one extra thumbnail tile when the chosen grid
            has more than one tile.

    Returns:
        ``(blocks, target_width, target_height)``: the tile count (including
        the thumbnail, if counted) and the size the image is resized to
        before it is cut into ``image_size`` squares.
    """
    # calculate the existing image aspect ratio
    aspect_ratio = orig_width / orig_height

    # Candidate (cols, rows) grids ordered by tile count. The construction is
    # kept as-is: the order of equal-count grids feeds the tie-break in
    # find_closest_aspect_ratio.
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    cols, rows = target_aspect_ratio[0], target_aspect_ratio[1]
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1

    return blocks, target_width, target_height


def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
                                 max_dynamic_patch: Optional[int] = None):
    """Bind ``calculate_num_blocks`` to the tiling settings of ``hf_config``.

    Args:
        hf_config: Config providing ``min_dynamic_patch``,
            ``max_dynamic_patch``, ``use_thumbnail`` and
            ``vision_config.image_size``.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when not
            ``None``.

    Returns:
        A callable ``(orig_width, orig_height) -> (blocks, target_width,
        target_height)``.
    """
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    min_num = hf_config.min_dynamic_patch
    image_size = hf_config.vision_config.image_size
    use_thumbnail = hf_config.use_thumbnail
    return partial(calculate_num_blocks,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   image_size=image_size,
                   use_thumbnail=use_thumbnail)


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    """Cut ``image`` into square ``image_size`` tiles.

    The grid is chosen by ``calculate_num_blocks``. The image is resized to
    the grid's pixel size and cropped row-major (left to right, then top to
    bottom). When ``use_thumbnail`` is set and more than one tile was
    produced, a thumbnail of the whole image resized to ``image_size`` is
    appended.

    Returns:
        The list of tiles, followed by the optional thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_num_blocks(
        orig_width,
        orig_height,
        min_num,
        max_num,
        image_size,
        use_thumbnail=False)

    # resize the image, then split it into a grid of `cols` tiles per row
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, cols)
        box = (col * image_size, row * image_size, (col + 1) * image_size,
               (row + 1) * image_size)
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    """Tile ``image`` and return the normalized tiles stacked along dim 0.

    Tiles come from ``dynamic_preprocess`` and each is passed through
    ``build_transform(input_size)``. That transform converts to RGB and
    resizes to a square, so each stacked tile is ``(3, input_size,
    input_size)``.
    """
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image,
                               min_num=min_num,
                               max_num=max_num,
                               image_size=input_size,
                               use_thumbnail=use_thumbnail)
    # Distinct loop name so the `image` argument is not shadowed.
    pixel_values = [transform(tile) for tile in tiles]
    return torch.stack(pixel_values)


def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
    """Bind ``image_to_pixel_values`` to the tiling settings of ``hf_config``.

    Args:
        hf_config: Config providing ``vision_config.image_size``,
            ``min_dynamic_patch``, ``max_dynamic_patch`` and
            ``use_thumbnail``.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when not
            ``None``.

    Returns:
        A callable ``(image) -> torch.Tensor`` of stacked pixel tiles.
    """
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    return partial(image_to_pixel_values,
                   input_size=image_size,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   use_thumbnail=use_thumbnail)


def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the number of ``IMG_CONTEXT`` tokens one image tile expands to.

    This is the ViT patch count for a single ``vision_config.image_size``
    tile scaled by ``downsample_ratio ** 2``, which is the reduction applied by
    ``InternVLChatModel.pixel_shuffle``. The result is truncated to ``int``.
    """
    vision_config = hf_config.vision_config
    downsample_ratio = hf_config.downsample_ratio
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size

    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))


def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
    """Return the largest number of ``IMG_CONTEXT`` tokens one image can use.

    The value is the maximum tile count times the tokens per tile from
    ``get_internvl_num_patches``. The maximum tile count is
    ``max_dynamic_patch``, plus one for the thumbnail when thumbnails are
    enabled and more than one tile is allowed.

    Args:
        ctx: Input context supplying the HF config.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when not
            ``None``.
    """
    hf_config = ctx.get_hf_config()

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    num_patches = get_internvl_num_patches(hf_config)
    return num_patches * max_dynamic_patch


def get_max_internvl_image_size(ctx: InputContext,
                                *,
                                max_dynamic_patch: Optional[int] = None):
    """Return ``(width, height)`` of the largest dummy image.

    The image is one tile high. It is as many tiles wide as the maximum tile
    count: ``max_dynamic_patch``, plus one for the thumbnail when thumbnails
    are enabled and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()
    tile = hf_config.vision_config.image_size

    num_tiles = (hf_config.max_dynamic_patch
                 if max_dynamic_patch is None else max_dynamic_patch)
    if hf_config.use_thumbnail and num_tiles > 1:
        num_tiles += 1
    return tile * num_tiles, tile


def input_processor_for_internvl(ctx: InputContext,
                                 llm_inputs: LLMInputs,
                                 *,
                                 max_dynamic_patch: Optional[int] = None):
    """Expand each ``<image>`` placeholder in the prompt into image tokens.

    The ``<image>`` placeholders are replaced in order by
    ``IMG_START + IMG_CONTEXT * n + IMG_END``, where ``n`` is the number of
    feature tokens of the matching image. If the prompt does not already
    label images with ``Image-N: <image>`` lines, each expansion is prefixed
    with ``Image-{i}: `` (1-based).

    Args:
        ctx: Input context supplying the model and HF configs.
        llm_inputs: Inputs with ``prompt_token_ids`` and optionally ``prompt``
            and ``multi_modal_data``.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when not
            ``None``.

    Returns:
        ``llm_inputs`` unchanged when there is no image data. Otherwise, new
        ``LLMInputs`` whose ``prompt_token_ids`` encode the expanded prompt,
        while ``prompt`` keeps the original (unexpanded) text.

    Raises:
        TypeError: If the image data is not a PIL image, a list of PIL images
            or a tensor.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()

    image_data = multi_modal_data["image"]

    num_patches = get_internvl_num_patches(hf_config)
    num_blocks_calculator = calculate_num_blocks_wrapper(
        hf_config, max_dynamic_patch)

    # Number of IMG_CONTEXT tokens to insert for each image, in prompt order.
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        num_blocks, _, _ = num_blocks_calculator(width, height)
        image_feature_size = [num_blocks * num_patches]
    elif is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            width, height = image.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            image_feature_size.append(num_blocks * num_patches)
    elif isinstance(image_data, torch.Tensor):
        # Precomputed embeddings of shape
        # (num_images, image_feature_size, hidden_size). Every image gets the
        # same count; build a per-image list because the expansion loop below
        # iterates over it (a bare int is not iterable).
        num_images, feature_size, _ = image_data.shape
        image_feature_size = [feature_size] * num_images
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    # Respect labels the user already wrote; otherwise add "Image-{idx}: ".
    has_image_labels = re.search(r"Image-(\d+): <image>\n",
                                 prompt) is not None
    new_prompt = prompt
    for idx, num_tokens in enumerate(image_feature_size, start=1):
        image_prompt = IMG_START + IMG_CONTEXT * num_tokens + IMG_END
        if not has_image_labels:
            image_prompt = f"Image-{idx}: {image_prompt}"
        new_prompt = new_prompt.replace('<image>', image_prompt, 1)

    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)


def input_mapper_for_internvl(ctx: InputContext,
                              data: object,
                              *,
                              max_dynamic_patch: Optional[int] = None):
    """Map raw image data to the model's multimodal keyword inputs.

    - A single PIL image becomes its stacked pixel tiles with a leading
      image dimension of 1.
    - A list of PIL images becomes a list of per-image tile tensors. They
      are not stacked because tile counts may differ.
    - A ``torch.Tensor`` is treated as precomputed image embeddings and is
      passed under ``image_embeds``.
    - Anything else is passed unchanged under ``pixel_values``.

    The encoded ``IMG_CONTEXT`` token id is always attached as
    ``image_token_id``.
    """
    hf_config = ctx.get_hf_config()
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    if isinstance(data, torch.Tensor):
        # input_processor_for_internvl reads a tensor as
        # (num_images, image_feature_size, hidden_size) embeddings. Sending
        # it as "pixel_values" would make the model validate it as
        # (3, H, W) pixel tiles and reject it.
        return MultiModalInputs({
            "image_embeds": data,
            "image_token_id": image_token_id
        })

    image_pixel_values_mapper = image_to_pixel_values_wrapper(
        hf_config, max_dynamic_patch)
    if isinstance(data, Image.Image):
        data = image_pixel_values_mapper(data)
        # Add an N dimension for number of images per prompt (currently 1).
        data = data.unsqueeze(0)
    elif is_list_of(data, Image.Image):
        # we can't stack here because the images may have different
        # num_patches
        data = [image_pixel_values_mapper(img) for img in data]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })


def dummy_data_for_internvl(ctx: InputContext,
                            seq_len: int,
                            mm_counts: Mapping[str, int],
                            *,
                            max_dynamic_patch: Optional[int] = None):
    """Build dummy sequence and image data sized for the largest image input.

    Args:
        ctx: Input context supplying the model and HF configs.
        seq_len: Length of the dummy token sequence.
        mm_counts: Number of items per modality; ``mm_counts["image"]``
            images are generated.
        max_dynamic_patch: Overrides ``hf_config.max_dynamic_patch`` when not
            ``None``.

    Returns:
        ``(seq_data, mm_data)`` from the CLIP dummy-data helpers. Each image
        takes ``get_max_internvl_image_tokens`` placeholder tokens and uses
        the size from ``get_max_internvl_image_size``.
    """
    num_images = mm_counts["image"]

    hf_config = ctx.get_hf_config()

    image_feature_size = get_max_internvl_image_tokens(
        ctx, max_dynamic_patch=max_dynamic_patch)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    seq_data = dummy_seq_data_for_clip(
        hf_config.vision_config,
        seq_len,
        num_images,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    max_image_width, max_image_height = get_max_internvl_image_size(
        ctx, max_dynamic_patch=max_dynamic_patch)

    mm_data = dummy_image_for_clip(
        hf_config.vision_config,
        num_images,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )

    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """InternVL chat model: InternViT encoder, MLP projector and LLM.

    Image tiles are encoded by ``vision_model``, downsampled with
    ``pixel_shuffle``, projected to the LLM hidden size by ``mlp1``, and
    written into the token embeddings at ``IMG_CONTEXT`` positions before
    ``language_model`` runs.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        # Patches per tile scaled by the pixel_shuffle reduction.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # Number of ViT layers needed to reach select_layer; a negative
        # select_layer counts back from the last layer.
        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # pixel_shuffle multiplies the channel dim by
        # (1 / downsample_ratio) ** 2, so the projector input is that much
        # wider than the ViT hidden size.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        # IMG_CONTEXT token id; filled in from the "image_token_id" kwarg by
        # _parse_and_validate_image_input.
        self.img_context_token_id = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Language model's sampler if it has one, else a default Sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        For input ``(N, W, H, C)`` the output is ``(N, H * s, W * s,
        C / s**2)`` when ``ps_version == 'v1'``. For any other version the two
        spatial axes are swapped back to ``(N, W * s, H * s, C / s**2)``.
        Here ``s`` is ``scale_factor``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // scale^2
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version != 'v1':
            # Restore the (W, H) spatial order of the input.
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel tiles into LLM-sized image embeddings.

        Returns ``(num_tiles, tokens_per_tile, llm_hidden_size)``.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the first token (presumably the ViT class token) and keep the
        # patch tokens.
        vit_embeds = vit_embeds[:, 1:, :]

        # Patch tokens are laid out on a square grid.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every tile in ``data`` has shape ``(3, H, W)``.

        ``H == W == vision_config.image_size``. Returns ``data`` unchanged.

        Raises:
            ValueError: If any tile has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_expr}. "
                    f"You supplied {actual_dims}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Pop image kwargs and build a typed image input, or ``None``.

        Also records ``img_context_token_id`` from ``image_token_id`` for use
        by ``forward``.

        Raises:
            ValueError: On wrong input types, wrong tile shapes, or when no
                ``IMG_CONTEXT`` token id is available.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        # Both input forms are merged at IMG_CONTEXT positions, so record the
        # token id before branching on the input form.
        if image_token_id is not None:
            self.img_context_token_id = image_token_id[0]
        if self.img_context_token_id is None:
            raise ValueError(
                "Missing `image_token_id`; it is required to locate the "
                f"{IMG_CONTEXT} placeholders in the prompt.")

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # We need to flatten (B, N, P) to (B*N*P),
            # so we call flatten_bn twice.
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(flatten_bn(pixel_values), concat=True)),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings, encoding pixel tiles when necessary."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model, splicing in image embeddings if present.

        When ``intermediate_tensors`` is given (presumably a later
        pipeline-parallel stage), only it is forwarded. Otherwise any image
        inputs in ``kwargs`` are encoded and merged into the input
        embeddings at ``img_context_token_id`` positions.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.img_context_token_id)
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the vision, projector and LLM parts.

        Weights are split with ``group_weights_with_prefix``. The
        ``vision_model``, ``mlp1`` and ``language_model`` groups are then
        loaded into the matching submodules.
        """
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.vision_model.load_weights(weights_group["vision_model"])

        # load mlp projector
        mlp_params_dict = dict(self.mlp1.named_parameters())
        for name, loaded_weight in weights_group["mlp1"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])