"vllm/vscode:/vscode.git/clone" did not exist on "77f62613f9eb0963dffca1f58d22f718505e80c7"
interns1.py 29.8 KB
Newer Older
Lyu Han's avatar
Lyu Han committed
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# --------------------------------------------------------
# InternS1
# Copyright (c) 2025 Shanghai AI Lab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
10
from typing import Annotated, Literal, Optional, Union
Lyu Han's avatar
Lyu Han committed
11

12
import regex as re
Lyu Han's avatar
Lyu Han committed
13
14
import torch
import torch.nn as nn
15
from transformers import BatchFeature, InternVLProcessor, PretrainedConfig
Lyu Han's avatar
Lyu Han committed
16
17
from transformers.activations import ACT2FN
from transformers.models.got_ocr2.image_processing_got_ocr2_fast import (
18
19
    GotOcr2ImageProcessorFast,
)
20
from transformers.models.internvl.video_processing_internvl import (
21
22
    InternVLVideoProcessor,
)
Lyu Han's avatar
Lyu Han committed
23
24

from vllm.config import VllmConfig
25
from vllm.config.multimodal import BaseDummyOptions
Lyu Han's avatar
Lyu Han committed
26
27
28
29
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.interns1_vit import InternS1VisionModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
Lyu Han's avatar
Lyu Han committed
48
49
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
50
from vllm.transformers_utils.processor import cached_video_processor_from_config
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Lyu Han's avatar
Lyu Han committed
52

53
54
55
56
57
58
59
60
61
62
63
64
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Lyu Han's avatar
Lyu Han committed
65
66
67
68
69


class InternS1MultiModalProjector(nn.Module):
    """Project pixel-shuffled vision features into the text hidden size.

    Input width is ``vision_hidden_size * int(1 / downsample_ratio) ** 2``
    (the channel width after pixel shuffle folds spatial windows into
    channels); output width is ``text_config.hidden_size``.
    Pipeline: LayerNorm -> Linear -> activation -> Linear.
    """

    def __init__(self, config):
        super().__init__()
        text_hidden = config.text_config.hidden_size
        in_features = (
            config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2
        )

        self.layer_norm = nn.LayerNorm(in_features)
        self.linear_1 = nn.Linear(in_features, text_hidden)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_hidden, text_hidden)

    def forward(self, image_features):
        """Normalize, project, activate, and project again."""
        projected = self.linear_1(self.layer_norm(image_features))
        return self.linear_2(self.act(projected))


90
class InternS1ImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for images, already split into tiles.

    Dimensions:
        - bnp: Batch size * number of images * (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bn: Batch size * number of images
    """

    type: Literal["pixel_values"] = "pixel_values"
    # Tiles of every image, concatenated along the first axis.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "h", "w"),
    ]
    # How many tiles each image contributes to ``pixel_values``.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]
Lyu Han's avatar
Lyu Han committed
103
104


105
class InternS1ImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, used as-is instead of running the encoder.

    Dimensions:
        - ni: Number of images
        - tifs: Total image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    # Either one stacked tensor or a list with one tensor per image.
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("ni", "tifs", "hs"),
    ]
Lyu Han's avatar
Lyu Han committed
117
118


119
# Image modality input: either raw pixel tiles or precomputed embeddings.
InternS1ImageInputs = Union[
    InternS1ImagePixelInputs,
    InternS1ImageEmbeddingInputs,
]
Lyu Han's avatar
Lyu Han committed
120
121


122
class InternS1VideoPixelInputs(TensorSchema):
    """
    Pixel-value inputs for videos; each frame is processed as one tile.

    Dimensions:
        - bnv: Batch size * number of videos * number of frames
        - bn: Batch size * number of videos
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    # Frames of every video, concatenated along the first axis.
    pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")]
    # How many frames each video contributes to ``pixel_values``.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
Lyu Han's avatar
Lyu Han committed
135
136


137
class InternS1VideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings, used as-is instead of running the encoder.

    Dimensions:
        - nv: Number of videos
        - tvfs: Total video feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["video_embeds"] = "video_embeds"
    # Either one stacked tensor or a list with one tensor per video.
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("nv", "tvfs", "hs"),
    ]
Lyu Han's avatar
Lyu Han committed
149
150


151
# Video modality input: either raw frame pixels or precomputed embeddings.
InternS1VideoInputs = Union[
    InternS1VideoPixelInputs,
    InternS1VideoEmbeddingInputs,
]
Lyu Han's avatar
Lyu Han committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172


def resolve_interns1_min_max_num(
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Return the effective ``(min_num, max_num)`` tile-count bounds.

    Without dynamic image sizing both bounds collapse to 1. When a thumbnail
    is requested and more than one tile is possible, the upper bound grows
    by one to leave room for the thumbnail tile.
    """
    if not dynamic_image_size:
        min_dynamic_patch = 1
        max_dynamic_patch = 1

    thumbnail_slot = 1 if (use_thumbnail and max_dynamic_patch != 1) else 0
    return min_dynamic_patch, max_dynamic_patch + thumbnail_slot


def get_interns1_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every ``(width_ratio, height_ratio)`` tile grid within bounds.

    A grid qualifies when ``min_num <= width_ratio * height_ratio <= max_num``.
    The result is ordered by ascending tile count.
    """
    ratios: set[tuple[int, int]] = set()
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if min_num <= i * j <= max_num:
                    ratios.add((i, j))

    def tile_count(grid: tuple[int, int]) -> int:
        return grid[0] * grid[1]

    return sorted(ratios, key=tile_count)


class InternS1ProcessingInfo(BaseProcessingInfo):
    """ProcessingInfo for InternS1-style models."""

    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Return the HF ``InternVLProcessor`` with its video processor attached.

        The video processor is loaded separately through
        ``cached_video_processor_from_config`` and assigned onto the processor.
        """
        hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
        hf_processor.video_processor = cached_video_processor_from_config(
            self.ctx.model_config, processor_cls=InternVLVideoProcessor, **kwargs
        )
        return hf_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # NOTE(review): ``None`` presumably means "no fixed per-prompt limit"
        # for the modality — confirm against BaseProcessingInfo.
        return {"image": None, "video": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional["GotOcr2ImageProcessorFast"] = None,
    ) -> int:
        """Return the number of image context tokens for an image of this size.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            processor: Image processor to use; defaults to the HF processor's.

        Raises:
            ValueError: if ``processor`` is not a ``GotOcr2ImageProcessorFast``.
        """
        # Fetch the HF processor once; it supplies both the default image
        # processor and ``image_seq_length``.
        hf_processor = self.get_hf_processor()
        if processor is None:
            processor = hf_processor.image_processor

        if not isinstance(processor, GotOcr2ImageProcessorFast):
            raise ValueError(
                f"GotOcr2ImageProcessorFast is expected but got {type(processor)}"
            )
        num_image_patches = processor.get_number_of_image_patches(
            image_height, image_width, images_kwargs=dict()
        )
        # Each tile expands to ``image_seq_length`` context tokens.
        return hf_processor.image_seq_length * num_image_patches

    def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None):
        """Return the candidate tile grids allowed by the image processor.

        Args:
            use_thumbnail: Whether a thumbnail tile is added. When ``None``
                (the default), the image processor's ``crop_to_patches``
                setting is used.
        """
        image_processor = self.get_hf_processor().image_processor
        # HF format's InternVL processor uses `crop_to_patches` which is
        # equivalent to `use_thumbnail` in original format. Only fall back to
        # it when the caller did not choose explicitly.
        if use_thumbnail is None:
            use_thumbnail = image_processor.crop_to_patches
        min_num, max_num = resolve_interns1_min_max_num(
            image_processor.min_patches,
            image_processor.max_patches,
            dynamic_image_size=True,
            use_thumbnail=use_thumbnail,
        )

        return get_interns1_target_ratios(min_num, max_num)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate image size that yields the most image tokens.

        Candidates are multiples of the vision tile size, one per target
        tile grid. On ties, the first grid in ``resolve_target_ratios`` order
        wins.
        """
        processor = self.get_hf_processor()

        hf_config = self.ctx.get_hf_config()
        base_height, base_width = hf_config.vision_config.image_size
        target_ratios = self.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_width * wr, base_height * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor.image_processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width, height=height)

        assert not (largest_feature_size == 0 or largest_feature_pinpoint is None), (
            "Cannot have a largest feature size of 0!"
        )

        return largest_feature_pinpoint

    def get_max_image_tokens(self) -> int:
        """Return the image token count of the largest-feature image size."""
        processor = self.get_hf_processor()
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=processor.image_processor,
        )

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return how many frames per video fit in ``seq_len``.

        The maximum image token budget for ``mm_counts["image"]`` images is
        reserved first. Each remaining frame costs ``image_seq_length``
        tokens, and frames are shared evenly across videos. Never less than 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        processor = self.get_hf_processor()

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = (seq_len - max_image_tokens) // processor.image_seq_length
        max_frames_per_video = max_total_frames // max(max_videos, 1)

        return max(max_frames_per_video, 1)

Lyu Han's avatar
Lyu Han committed
285

286
class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]):
    """DummyInputsBuilder for InternS1-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        n_images = mm_counts.get("image", 0)
        n_videos = mm_counts.get("video", 0)

        image_token = self.info.get_hf_processor().image_token
        video_token = self.info.get_hf_processor().video_token

        return "".join((image_token * n_images, video_token * n_videos))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Build dummy images and videos for multimodal profiling.

        Images use the size that maximises the image token count. Video frames
        use the vision tile size, with as many frames per video as
        ``get_num_frames_with_most_features`` allows for ``seq_len``.
        """
        target_width, target_height = self.info.get_image_size_with_most_features()
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )
        n_images = mm_counts.get("image", 0)
        n_videos = mm_counts.get("video", 0)

        frame_h, frame_w = self.info.get_hf_config().vision_config.image_size

        options = mm_options or {}
        image_overrides = options.get("image")
        video_overrides = options.get("video")

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=n_images,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=frame_w,
            height=frame_h,
            num_frames=frames_per_video,
            num_videos=n_videos,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}


333
334
class InternS1MultiModalProcessor(BaseMultiModalProcessor[InternS1ProcessingInfo]):
    """MultiModalProcessor for InternS1-style models (images and videos)."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor item by item and retokenize the full prompt.

        Each image/video is processed on its own. Its expanded placeholder
        text, decoded from the HF output ids, is spliced into ``prompt`` in
        order. For videos, image token ids are first rewritten to the video
        token id.

        Returns:
            Text outputs merged with ``pixel_values``/``image_num_patches``/
            ``image_token_id`` (if images) and ``pixel_values_videos``/
            ``video_num_patches``/``video_token_id`` (if videos).
        """
        # Temporary markers that stand in for the image/video tokens while
        # each item's expanded text is spliced in. Without them, the tokens
        # inside already-expanded text would be matched again.
        image_marker = "<image_placeholder>"
        video_marker = "<video_placeholder>"

        mm_data = dict(mm_data)
        videos = mm_data.pop("videos", [])
        images = mm_data.pop("images", [])
        assert isinstance(videos, list)
        assert isinstance(images, list)

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        tokenizer = hf_processor.tokenizer
        video_token_id = tokenizer.encode(
            hf_processor.video_token, add_special_tokens=False
        )
        assert len(video_token_id) == 1
        video_token_id = video_token_id[0]

        # The tokens are literal strings, so use plain replacement. Regex
        # substitution would misread metacharacters in the token, and
        # backslashes in the replacement text.
        prompt = prompt.replace(hf_processor.image_token, image_marker)
        prompt = prompt.replace(hf_processor.video_token, video_marker)

        image_outputs = {}
        if images:
            image_pixel_values = []
            for image in images:
                processed_outputs = super()._call_hf_processor(
                    prompt=hf_processor.image_token,
                    mm_data={"images": image},
                    mm_kwargs=mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                image_pixel_values.append(processed_outputs.pop("pixel_values"))

                input_ids = processed_outputs.pop("input_ids")
                image_placeholder = tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(image_marker, image_placeholder, 1)

            num_patches = [len(item) for item in image_pixel_values]
            image_outputs = {
                "pixel_values": torch.concat(image_pixel_values),
                "image_num_patches": torch.tensor(num_patches),
                "image_token_id": torch.tensor(hf_processor.image_token_id),
            }

        video_outputs = {}
        if videos:
            video_pixel_values = []
            for video in videos:
                processed_outputs = super()._call_hf_processor(
                    prompt=hf_processor.video_token,
                    mm_data={"videos": video},
                    mm_kwargs=mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                video_pixel_values.append(processed_outputs.pop("pixel_values"))

                input_ids = processed_outputs.pop("input_ids")
                # The HF processor expands video frames with image tokens;
                # swap them for the video token so the modalities stay distinct.
                input_ids[input_ids == hf_processor.image_token_id] = video_token_id

                video_placeholder = tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(video_marker, video_placeholder, 1)

            num_frames = [len(item) for item in video_pixel_values]
            video_outputs = {
                "pixel_values_videos": torch.concat(video_pixel_values),
                "video_num_patches": torch.tensor(num_frames),
                "video_token_id": torch.tensor(video_token_id),
            }

        # Restore any markers that had no matching item.
        prompt = prompt.replace(image_marker, hf_processor.image_token)
        prompt = prompt.replace(video_marker, hf_processor.video_token)
        text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")

        return BatchFeature({**text_outputs, **image_outputs, **video_outputs})

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split per image/video item.

        Pixel tensors are flat and split by each item's tile/frame count. Per-item
        counts and embeddings are batched. Token ids are shared across items.
        """
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
        num_images = len(image_num_patches)
        num_videos = len(video_num_patches)

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches
            ),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_num_patches
            ),
            video_num_patches=MultiModalFieldConfig.batched("video"),
            video_token_id=MultiModalFieldConfig.shared("video", num_videos),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video token into its full placeholder sequence.

        Image: ``start + IMG_CONTEXT * feature_size + end``.
        Video: one ``"Frame{i}: start + video_token * image_seq_length + end"``
        line per frame.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        img_context_token = hf_processor.image_token
        start_image_token = hf_processor.start_image_token
        end_image_token = hf_processor.end_image_token
        video_token = hf_processor.video_token

        out_mm_data = out_mm_kwargs.get_data()
        if "video_num_patches" in out_mm_data:
            video_num_patches = out_mm_data["video_num_patches"]
            assert isinstance(video_num_patches, torch.Tensor)
            video_num_patches = video_num_patches.tolist()
        else:
            video_num_patches = []

        if "image_num_patches" in out_mm_data:
            image_num_patches = out_mm_data["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        else:
            image_num_patches = []

        def get_replacement_interns1_image(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                num_patches = image_num_patches[item_idx]
                feature_size = num_patches * hf_processor.image_seq_length

            repl_features = img_context_token * feature_size
            repl_full = start_image_token + repl_features + end_image_token
            return PromptUpdateDetails.select_text(repl_full, img_context_token)

        def get_replacement_interns1_video(item_idx: int):
            num_patches = video_num_patches[item_idx]
            repl_features = video_token * hf_processor.image_seq_length
            repl_features_with_sep = start_image_token + repl_features + end_image_token
            # num_patches is equal to num_frames
            repl_full = "\n".join(
                [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
            )

            return PromptUpdateDetails.select_text(repl_full, video_token)

        return [
            PromptReplacement(
                modality="image",
                target=img_context_token,
                replacement=get_replacement_interns1_image,
            ),
            PromptReplacement(
                modality="video",
                target=video_token,
                replacement=get_replacement_interns1_video,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    InternS1MultiModalProcessor,
    info=InternS1ProcessingInfo,
508
509
510
511
512
    dummy_inputs=InternS1DummyInputsBuilder,
)
class InternS1ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
513
    # NOTE(review): opts this model into field-config-based merging of batched
    # multimodal kwargs — confirm semantics against the SupportsMultiModal docs.
    merge_by_field_config = True

    # To ensure correct weight loading and mapping: rename HF checkpoint
    # prefixes onto this module's attributes (``language_model``,
    # ``vision_tower``, ``multi_modal_projector``).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Text backbone and its output head.
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            # Vision encoder and vision-to-text projector.
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
        },
    )
Lyu Han's avatar
Lyu Han committed
524
525
526

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder string for an item of ``modality``.

        Raises:
            ValueError: if ``modality`` is neither an image nor a video modality.
        """
        # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator
        # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116
        placeholders = (("image", "<IMG_CONTEXT>"), ("video", "<video>"))
        for prefix, placeholder in placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, language model and projector from config.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config`` is the
                InternS1 HF config with ``vision_config`` and ``text_config``.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        vision_config = config.vision_config
        tile_size = vision_config.image_size[0]
        self.patch_size = vision_config.patch_size[0]
        # Patch-grid area of one tile, scaled by downsample_ratio**2.
        patches_per_side = tile_size // self.patch_size
        self.num_image_token = int(
            patches_per_side**2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio

        self.llm_arch_name = config.text_config.architectures[0]
        self.vision_tower = self._init_vision_model(
            config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.multi_modal_projector = self._init_mlp1(config)

        # Set later, when pixel inputs are parsed.
        self.img_context_token_id = None
        self.video_context_token_id = None

        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
Lyu Han's avatar
Lyu Han committed
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        prefix: str,
    ):
        """Build the InternS1 vision encoder with its configured layer count."""
        vision_config = config.vision_config
        return InternS1VisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=vision_config.num_hidden_layers,
            prefix=prefix,
        )

591
    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the projector that maps vision features to the text hidden size."""
        projector = InternS1MultiModalProjector(config)
        return projector

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel width.

        Maps ``(N, W, H, C)`` to ``(N, W * s, H * s, C / s**2)`` with
        ``s = scale_factor``.
        """
        n, w, h, c = x.size()
        new_h, new_w = int(h * scale_factor), int(w * scale_factor)

        # Fold along H into channels: (N, W, H*s, C/s).
        x = x.view(n, w, new_h, int(c / scale_factor))
        # Swap spatial axes: (N, H*s, W, C/s).
        x = x.permute(0, 2, 1, 3).contiguous()
        # Fold along W as well: (N, H*s, W*s, C/s^2).
        x = x.view(n, new_h, new_w, int(c / (scale_factor * scale_factor)))
        # Restore the original axis order: (N, W*s, H*s, C/s^2).
        return x.permute(0, 2, 1, 3).contiguous()

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode tiles and project them into the language model's hidden size.

        Returns a ``(num_tiles, tokens_per_tile, text_hidden_size)`` tensor.
        """
        encoded = self.vision_tower(pixel_values=pixel_values)
        # Drop the first token (presumably a CLS token) and keep patch tokens.
        patch_tokens = encoded[:, 1:, :]

        num_tiles = patch_tokens.shape[0]
        side = int(patch_tokens.shape[1] ** 0.5)  # assumes a square patch grid
        grid = patch_tokens.reshape(num_tiles, side, side, -1)
        shuffled = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
        flattened = shuffled.reshape(shuffled.shape[0], -1, shuffled.shape[-1])

        return self.multi_modal_projector(flattened)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[InternS1ImageInputs]:
        """Build the image input schema from model kwargs, or return ``None``.

        Embeddings take precedence over pixel values. For pixel inputs the
        image context token id is read from ``image_token_id`` and stored on
        ``self.img_context_token_id``.
        """
        pixels = kwargs.pop("pixel_values", None)
        tiles_per_image = kwargs.pop("image_num_patches", None)
        embeds = kwargs.pop("image_embeds", None)

        if embeds is not None:
            return InternS1ImageEmbeddingInputs(
                type="image_embeds",
                data=embeds,
            )
        if pixels is None:
            return None

        token_id = kwargs["image_token_id"]
        assert isinstance(token_id, torch.Tensor)
        self.img_context_token_id = token_id.flatten().unique().item()

        h, w = self.config.vision_config.image_size
        return InternS1ImagePixelInputs(
            type="pixel_values",
            pixel_values=pixels,
            num_patches=tiles_per_image,
            resolve_bindings={"h": h, "w": w},
        )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Optional[InternS1VideoInputs]:
        """Build the video input schema from model kwargs, or return ``None``.

        Embeddings take precedence over pixel values. For pixel inputs the
        video context token id is read from ``video_token_id`` and stored on
        ``self.video_context_token_id``.
        """
        frame_pixels = kwargs.pop("pixel_values_videos", None)
        frames_per_video = kwargs.pop("video_num_patches", None)
        embeds = kwargs.pop("video_embeds", None)

        if embeds is not None:
            return InternS1VideoEmbeddingInputs(
                type="video_embeds",
                data=embeds,
            )
        if frame_pixels is None:
            return None

        token_id = kwargs["video_token_id"]
        assert isinstance(token_id, torch.Tensor)
        self.video_context_token_id = token_id.flatten().unique().item()

        h, w = self.config.vision_config.image_size
        return InternS1VideoPixelInputs(
            type="pixel_values_videos",
            num_patches=frames_per_video,
            pixel_values=frame_pixels,
            resolve_bindings={"h": h, "w": w},
        )

689
    def _process_vision_input(
        self,
        image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
    ) -> tuple[torch.Tensor, ...]:
        """Turn parsed image/video inputs into per-item embedding tensors.

        Precomputed embeddings are returned unchanged. Pixel inputs go through
        the vision tower and projector. The result is then split so that each
        item (image or video) gets ``num_patches[i] * tokens_per_tile`` rows
        of width ``text_config.hidden_size``.
        """
        if image_input["type"] in ("image_embeds", "video_embeds"):
            return image_input["data"]

        assert self.vision_tower is not None

        image_embeds = self.extract_feature(image_input["pixel_values"])
        num_patches = image_input["num_patches"]

        # Tokens produced per tile; read before tiles are flattened together.
        tokens_per_tile = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)

        # Only one image in the current batch
        if len(num_patches) == 1:
            return (image_embeds,)

        # NOTE: Image embeddings are split into separate tensors for each item
        # by the number of tiles it contributed. Convert counts to plain ints
        # (they may be 0-d tensors) and avoid shadowing ``num_patches``.
        item_sizes = [int(count) * tokens_per_tile for count in num_patches]
        return image_embeds.split(item_sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
724
725
726
727
728
729
730
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if input_key in ("pixel_values_videos",) and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
Lyu Han's avatar
Lyu Han committed
731
732
733
734
735
736
737
738
739

        return modalities

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        self.visual_token_mask = None

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text backbone stored as ``self.language_model``."""
        backbone = self.language_model
        return backbone

740
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for the image/video inputs found in ``kwargs``.

        Returns an empty list when no multimodal inputs are parsed. Otherwise
        it returns a tuple concatenating, modality by modality, the tensors
        produced by ``_process_vision_input``. Modalities are taken in the
        order of the parsed dict, and only ``"images"`` and ``"videos"`` are
        considered.
        """
        parsed = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed:
            return []

        embeddings: tuple[torch.Tensor, ...] = ()
        # Walk the dict itself so the modality order is preserved.
        for modality, modality_input in parsed.items():
            if modality not in ("images", "videos"):
                continue
            embeddings += self._process_vision_input(modality_input)
        return embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in multimodal embeddings when given.

        ``_set_visual_token_mask`` is called first whenever a non-empty
        ``multimodal_embeddings`` is supplied. The parent implementation then
        receives the multimodal arguments only if both
        ``multimodal_embeddings`` and ``is_multimodal`` are provided.
        Otherwise it is called with ``input_ids`` alone.
        """
        have_mm = multimodal_embeddings is not None
        if have_mm and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # Two distinct super() call shapes keep each overload of the parent
        # signature satisfied for the type checker.
        if have_mm and is_multimodal is not None:
            return super().get_input_embeddings(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
                handle_oov_mm_token=handle_oov_mm_token,
            )
        return super().get_input_embeddings(input_ids)
Lyu Han's avatar
Lyu Han committed
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language-model backbone.

        When ``intermediate_tensors`` is given, ``input_ids`` and
        ``inputs_embeds`` are discarded (set to ``None``) before calling
        ``self.language_model.model``. Extra ``kwargs`` are accepted but not
        forwarded.

        Returns:
            The output of ``self.language_model.model``: either a hidden-state
            tensor or ``IntermediateTensors``. Presumably the latter is
            returned on non-final pipeline stages; confirm against the
            backbone.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to logits.

        Delegates to ``self.language_model.compute_logits`` and returns its
        result unchanged.
        """
        language_model = self.language_model
        return language_model.compute_logits(hidden_states)
Lyu Han's avatar
Lyu Han committed
812

813
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Weight names are remapped with ``self.hf_to_vllm_mapper``, and the
        actual loading is delegated to ``AutoWeightsLoader``.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``. Presumably
            this is the names of the parameters that were loaded.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Describe the module prefixes of this multimodal model.

        Returns:
            A ``MultiModelKeys`` naming ``language_model`` as the language
            model, ``multi_modal_projector`` as the connector, and
            ``vision_tower`` as the tower model.
        """
        prefixes = dict(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )
        return MultiModelKeys.from_string_field(**prefixes)