internvl.py 19.2 KB
Newer Older
1
2
3
4
5
6
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
7
import itertools
8
9
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
10
11
12
13
14
15
16
17
18
19
20

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
21
from vllm.model_executor.layers.sampler import SamplerOutput
22
23
24
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.base import MultiModalInputs
27
from vllm.multimodal.utils import cached_get_tokenizer
28
from vllm.sequence import IntermediateTensors
29
30
31

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
32
from .interfaces import SupportsMultiModal
33
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
34
                    merge_multimodal_embeddings)
35
36
37
38
39
40
41
42
43
44
45

# Special tokens of the InternVL prompt format. input_processor_for_internvl
# replaces the first "<image>" in a prompt with
# IMG_START + IMG_CONTEXT * <num image tokens> + IMG_END, and the model later
# overwrites the embeddings at the IMG_CONTEXT token positions with image
# features.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel (R, G, B) mean / std passed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input given as preprocessed pixel tiles for the vision encoder."""

    # Discriminator checked by InternVLChatModel._process_image_input.
    type: Literal["pixel_values"]

    # Tiles of every image in the batch (including any thumbnail tile),
    # concatenated along dim 0:
    # `(total_num_tiles, num_channels, height, width)`, where each tile is
    # validated to be `(3, image_size, image_size)`.
    data: torch.Tensor


53
54
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (vision encoder skipped)."""

    # Discriminator checked by InternVLChatModel._process_image_input.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of the language model backbone.
    data: torch.Tensor


# Image input accepted by InternVLChatModel; the "type" key tells
# _process_image_input whether the vision encoder must be run.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    """Return the torchvision pipeline that turns one PIL tile into a tensor.

    Steps: convert to RGB mode if needed, bicubic-resize to
    ``input_size`` x ``input_size``, convert to a tensor, then normalize with
    IMAGENET_MEAN / IMAGENET_STD.
    """

    def _ensure_rgb(img):
        # Only convert images whose mode is not already RGB.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Pick the tiling grid whose aspect ratio best matches the image.

    Args:
        aspect_ratio: ``width / height`` of the source image.
        target_ratios: Candidate ``(cols, rows)`` grids, scanned in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile.

    Returns:
        The chosen candidate, or ``(1, 1)`` when ``target_ratios`` is empty.
        A strictly closer aspect ratio always wins; on an exact tie, a later
        candidate replaces the current choice only if the image area exceeds
        half of that candidate's total tile area.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and image_area > (
                0.5 * image_size * image_size * candidate[0] * candidate[1]):
            chosen = candidate
    return chosen


97
98
99
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int,
                         image_size: int) -> Tuple[int, int, int]:
    """Choose the tile grid for an image and report its size.

    Args:
        orig_width: Source image width in pixels.
        orig_height: Source image height in pixels (must be non-zero).
        min_num: Minimum number of tiles a grid may contain.
        max_num: Maximum number of tiles a grid may contain.
        image_size: Side length of one square tile.

    Returns:
        ``(num_tiles, resized_width, resized_height)``: the tile count of the
        chosen ``(cols, rows)`` grid and the grid's size in pixels
        (``cols * image_size``, ``rows * image_size``).
    """
    image_aspect = orig_width / orig_height

    # Enumerate every (cols, rows) grid with min_num <= cols * rows <= max_num,
    # then order by tile count so smaller grids are scanned first.
    grids = set((i, j) for n in range(min_num, max_num + 1)
                for i in range(1, n + 1) for j in range(1, n + 1)
                if i * j <= max_num and i * j >= min_num)
    grids = sorted(grids, key=lambda grid: grid[0] * grid[1])

    best = find_closest_aspect_ratio(image_aspect, grids, orig_width,
                                     orig_height, image_size)
    cols, rows = best[0], best[1]
    return cols * rows, image_size * cols, image_size * rows


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    """Split an image into ``image_size`` x ``image_size`` tiles.

    The image is resized to the grid chosen by ``calculate_num_blocks`` and
    cropped into tiles in row-major order. When ``use_thumbnail`` is truthy
    and more than one tile was produced, the whole original image resized to
    a single tile is appended last.

    Args:
        image: Source PIL image.
        min_num: Minimum number of tiles in the grid.
        max_num: Maximum number of tiles in the grid.
        image_size: Side length of one tile in pixels.
        use_thumbnail: Whether to append a thumbnail tile for multi-tile grids.

    Returns:
        The list of tile images (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size

    blocks, target_width, target_height = calculate_num_blocks(
        orig_width, orig_height, min_num, max_num, image_size)
    # resize the image so it divides exactly into the chosen grid
    resized_img = image.resize((target_width, target_height))
    # number of tile columns in the grid; loop-invariant
    cols = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, cols)
        left, top = col * image_size, row * image_size
        # split the image
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size)))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    """Tile ``image`` and return the transformed tiles stacked along dim 0.

    Tiling is done by ``dynamic_preprocess``; each tile is then converted by
    ``build_transform(input_size)``. The result therefore has one entry per
    tile (including the thumbnail, if one was added).
    """
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image,
                               min_num=min_num,
                               max_num=max_num,
                               image_size=input_size,
                               use_thumbnail=use_thumbnail)
    return torch.stack([to_tensor(tile) for tile in tiles])


def get_internvl_num_patches(image_size: int, patch_size: int,
                             downsample_ratio: float):
    """Return the number of LLM tokens one ``image_size`` tile produces.

    The ViT patch count from ``get_clip_num_patches`` is scaled by
    ``downsample_ratio ** 2`` and truncated to an int. This mirrors
    ``InternVLChatModel.num_image_token``.
    """
    vit_patches = get_clip_num_patches(image_size=image_size,
                                       patch_size=patch_size)
    return int(vit_patches * (downsample_ratio**2))


def get_max_internvl_image_tokens(ctx: InputContext):
    """Upper bound on the image placeholder tokens for a single image.

    Equals tokens-per-tile times ``max_dynamic_patch``, with one extra tile
    counted when ``use_thumbnail`` is enabled (the thumbnail tile).
    """
    hf_config = ctx.get_hf_config()
    vision_config = hf_config.vision_config

    max_tiles = hf_config.max_dynamic_patch
    if hf_config.use_thumbnail:
        max_tiles = max_tiles + 1
    downsample_ratio = hf_config.downsample_ratio

    tokens_per_tile = get_internvl_num_patches(vision_config.image_size,
                                               vision_config.patch_size,
                                               downsample_ratio)
    return tokens_per_tile * max_tiles
182
183
184
185
186
187
188
189


def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand the first ``<image>`` placeholder into InternVL image tokens.

    The number of ``IMG_CONTEXT`` tokens is computed as follows:
    - PIL image: the tile count from ``calculate_num_blocks``, plus one
      thumbnail tile when enabled and more than one tile is used, multiplied
      by the tokens per tile.
    - Tensor: its first dimension.

    Inputs without image data are returned unchanged.

    NOTE: the returned ``prompt`` is the original (unexpanded) text, while
    ``prompt_token_ids`` are re-encoded from the expanded prompt.

    Raises:
        TypeError: if the image data is neither a PIL image nor a tensor.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()
    vision_config = hf_config.vision_config

    tile_size = vision_config.image_size
    tokens_per_tile = get_internvl_num_patches(tile_size,
                                               vision_config.patch_size,
                                               hf_config.downsample_ratio)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        num_tiles, _, _ = calculate_num_blocks(width, height,
                                               hf_config.min_dynamic_patch,
                                               hf_config.max_dynamic_patch,
                                               tile_size)
        # Same rule as dynamic_preprocess: the thumbnail tile is only added
        # when the image was split into more than one tile.
        if hf_config.use_thumbnail and num_tiles > 1:
            num_tiles += 1
        image_feature_size = num_tiles * tokens_per_tile
    elif isinstance(image_data, torch.Tensor):
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
    expanded_prompt = prompt.replace('<image>', image_prompt, 1)
    expanded_token_ids = tokenizer.encode(expanded_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=expanded_token_ids,
                     multi_modal_data=multi_modal_data)


def input_mapper_for_internvl(ctx: InputContext, data: object):
    """Map one request's image to model keyword arguments.

    A PIL image is tiled and transformed by ``image_to_pixel_values``, then
    given a leading dimension of size 1 (one image per prompt). Any other
    ``data`` is forwarded unchanged as ``pixel_values``. The tokenizer's
    encoding of IMG_CONTEXT is attached as ``image_token_id`` so the model
    can locate the placeholder positions.
    """
    hf_config = ctx.get_hf_config()

    thumbnail_enabled = hf_config.use_thumbnail
    min_tiles = hf_config.min_dynamic_patch
    max_tiles = hf_config.max_dynamic_patch
    tile_size = hf_config.vision_config.image_size

    pixel_values = data
    if isinstance(data, Image.Image):
        tiles = image_to_pixel_values(data,
                                      tile_size,
                                      min_tiles,
                                      max_tiles,
                                      use_thumbnail=thumbnail_enabled)
        # Add an N dimension for number of images per prompt (currently 1).
        pixel_values = tiles.unsqueeze(0)

    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer,
                                     trust_remote_code=True)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": pixel_values,
        "image_token_id": image_token_id
    })


261
262
263
def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    """Build dummy prompt tokens and images for the dummy-data registry hook.

    The dummy sequence reserves ``get_max_internvl_image_tokens(ctx)``
    IMG_CONTEXT tokens per image. The dummy image is sized so that it is
    processed like a worst-case input (presumably for profiling; verify
    against the registry's caller).

    Returns:
        ``(seq_data, mm_data)`` from the CLIP dummy-data helpers.
    """
    num_images = mm_counts["image"]

    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()
    vision_config = hf_config.vision_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    context_token_id = tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0]

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=context_token_id,
        image_feature_size_override=image_feature_size,
    )

    tile = vision_config.image_size
    # Width spans max_dynamic_patch tiles and height min_dynamic_patch tiles.
    # The aspect ratio max/min presumably makes calculate_num_blocks pick the
    # largest grid; this holds when min_dynamic_patch == 1.
    mm_data = dummy_image_for_clip(
        vision_config,
        num_images,
        image_width_override=hf_config.max_dynamic_patch * tile,
        image_height_override=hf_config.min_dynamic_patch * tile,
    )

    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal):
    """InternVL chat model: InternViT encoder + MLP projector + language model.

    Image tiles are encoded by ``vision_model``, spatially downsampled with
    ``pixel_shuffle``, projected by ``mlp1`` to the language model's hidden
    size, and written into the text embeddings at the positions holding the
    image-context token id.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        # LLM tokens per tile after pixel-shuffle downsampling.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # Number of ViT layers to build so that its output is the
        # `select_layer` features (negative values count back from the last
        # layer). Presumably InternVisionModel truncates to this depth.
        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # pixel_shuffle folds (1 / downsample_ratio)**2 neighbouring ViT
        # features into one vector, hence the widened projector input.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        # Filled from the `image_token_id` kwarg that the input mapper
        # attaches to image requests; see _parse_and_validate_image_input.
        self.img_context_token_id = None

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels (space-to-depth).

        Args:
            x: Tensor of shape ``(n, w, h, c)``.
            scale_factor: Spatial scale; ``h``, ``w`` are multiplied by it and
                ``c`` divided by its square.

        Returns:
            Tensor of shape ``(n, h*s, w*s, c/(s*s))`` for ``ps_version ==
            'v1'``; for other versions the two spatial axes are transposed
            once more.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version != 'v1':
            # Extra transpose of the spatial axes (presumably restoring the
            # original H/W orientation); v1 checkpoints omit it.
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode image tiles into language-model-sized embeddings.

        Returns a tensor of shape ``(num_tiles, tokens_per_tile,
        llm_hidden_size)`` as produced by ``mlp1``.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the first token (presumably the class token) so that the rest
        # form the patch grid -- TODO confirm against InternVisionModel.
        vit_embeds = vit_embeds[:, 1:, :]

        # Assumes a square grid of patch tokens.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check every tile has shape ``(3, image_size, image_size)``.

        Returns ``data`` unchanged.

        Raises:
            ValueError: on the first tile with an unexpected shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Pop the image kwargs and wrap them in a typed input dict.

        Returns None when neither ``pixel_values`` nor ``image_embeds`` is
        present. Otherwise records ``img_context_token_id`` (needed by
        ``forward`` for both input kinds) and returns the embeddings input
        if given, else the validated pixel input.

        Raises:
            ValueError: on wrong input types, bad tile shapes, or when image
                inputs arrive and no ``image_token_id`` has ever been seen.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        # forward() merges embeddings at the context-token positions for BOTH
        # pixel and embedding inputs, so record the id before the embeddings
        # early-return (previously it was only set on the pixel path).
        # Assumes a batched container whose first entry is the id produced by
        # input_mapper_for_internvl -- TODO confirm batching shape.
        if image_token_id is not None:
            self.img_context_token_id = image_token_id[0]
        elif self.img_context_token_id is None:
            raise ValueError("Image inputs were given without an "
                             "`image_token_id`, so the image placeholder "
                             "token cannot be located.")

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # Here pixel_values is guaranteed to be non-None.
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(
                flatten_bn(pixel_values, concat=True).flatten(0, 1)),
        )

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings, running the vision encoder if needed."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the language model, splicing in image embeddings if present.

        When image kwargs are present, the text embeddings are computed here,
        the image embeddings are merged at ``img_context_token_id`` positions,
        and the LLM is called with ``inputs_embeds`` instead of ``input_ids``.
        Returns the language model's hidden states.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.img_context_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route checkpoint tensors to the vision encoder, projector and LLM.

        ``weights`` may be a one-shot iterator, so it is tee'd into three
        iterators. Each one is filtered by its top-level module prefix
        (``vision_model``, ``mlp1``, ``language_model``).
        """
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_model")
        self.vision_model.load_weights(vit_weights)

        # load mlp projector; names are looked up relative to self.mlp1, so
        # filter_weights presumably strips the "mlp1" prefix -- confirm.
        mlp_weights = filter_weights(mlp_weights, "mlp1")
        mlp_params_dict = dict(self.mlp1.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)