# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, TypeAlias

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.intern_vit import (
    InternVisionModel,
    InternVisionPatchModel,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
55

56
57
58
# Text markers used when expanding an "<image>" placeholder in the prompt:
# IMG_START and IMG_END wrap a run of repeated IMG_CONTEXT tokens.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<IMG_CONTEXT>"

# Per-channel mean / std applied by the normalization step of the image
# transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


64
class SkyworkR1VImagePixelInputs(TensorSchema):
    """
    Schema for image inputs given as raw pixel tiles.

    Dimensions:
        - bnp: Batch size * number of images * (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bn: Batch size * number of images
    """

    # Discriminator used to tell this schema apart from the embedding one.
    type: Literal["pixel_values"] = "pixel_values"

    # All tiles of all images concatenated along the first dimension.
    pixel_values_flat: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "h", "w"),
    ]

    # Number of tiles that belong to each image.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]
85
86


87
class SkyworkR1VImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs given as precomputed embeddings.

    Dimensions:
        - ni: Number of images
        - ifs: Image feature size
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    # Discriminator used to tell this schema apart from the pixel one.
    type: Literal["image_embeds"] = "image_embeds"

    # Either one stacked tensor or one tensor per image.
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("ni", "ifs", "hs"),
    ]
102
103


104
105
106
# Union of the two accepted image input forms: raw pixel tiles or
# precomputed image embeddings.
SkyworkR1VImageInputs: TypeAlias = (
    SkyworkR1VImagePixelInputs | SkyworkR1VImageEmbeddingInputs
)
107
108
109
110
111


# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to every image tile.

    The pipeline converts the image to RGB, resizes it to a square of
    ``input_size`` x ``input_size`` with bicubic interpolation, converts it
    to a tensor and normalizes it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length, in pixels, of the square output image.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    return T.Compose(
        [
            T.Lambda(lambda img: convert_image_mode(img, "RGB")),
            T.Resize(
                (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC
            ),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
122
123
124
125
126
127
128
129
130
131
132


# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width / height of the original image.
        target_ratios: Candidate ``(columns, rows)`` grids, examined in order.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The best matching ``(columns, rows)`` pair, or ``(1, 1)`` when
        ``target_ratios`` is empty.
    """
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie on aspect ratio: switch to this grid only when the image
            # area exceeds half of the total pixel area the grid would cover.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def resolve_skyworkr1v_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective lower and upper bound on tiles per image.

    Without dynamic image sizing both bounds collapse to 1. When a thumbnail
    is used and the upper bound is not 1, the upper bound grows by one.

    Returns:
        ``(min_num, max_num)``.
    """
    if dynamic_image_size:
        lower, upper = min_dynamic_patch, max_dynamic_patch
    else:
        lower = upper = 1

    wants_extra_tile = use_thumbnail and upper != 1
    if wants_extra_tile:
        upper += 1

    return lower, upper


def get_skyworkr1v_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate every tiling grid whose tile count lies in a given range.

    Args:
        min_num: Minimum number of tiles (inclusive).
        max_num: Maximum number of tiles (inclusive).

    Returns:
        The unique ``(i, j)`` pairs with ``min_num <= i * j <= max_num``,
        sorted by ascending tile count ``i * j``. The list is empty when
        ``min_num > max_num``.
    """
    # A set removes the duplicates that different values of ``n`` produce.
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    return sorted(target_ratios, key=lambda x: x[0] * x[1])


def calculate_skyworkr1v_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile count and resized dimensions for an image.

    Returns:
        ``(blocks, target_width, target_height)`` where ``blocks`` is the
        number of tiles, plus one for the thumbnail when ``use_thumbnail`` is
        set and the grid holds more than one tile.
    """
    # Grid (columns, rows) whose aspect ratio best matches the image.
    grid = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )
    columns = grid[0]
    rows = grid[1]

    # Pixel dimensions the image is resized to before it is cut into tiles.
    resized_width = image_size * columns
    resized_height = image_size * rows

    tile_count = columns * rows
    # A thumbnail tile is only added when the image is actually split.
    thumbnail_tiles = 1 if (use_thumbnail and tile_count != 1) else 0

    return tile_count + thumbnail_tiles, resized_width, resized_height


def dynamic_preprocess_skyworkr1v(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize an image to the best matching grid and cut it into tiles.

    Args:
        image: Source image.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of one square tile in pixels.
        use_thumbnail: Whether to append a resized copy of the whole image
            when the image was split into more than one tile.

    Returns:
        The tiles in row-major order, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_skyworkr1v_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # Number of tiles in one row of the grid (loop-invariant).
    tiles_per_row = target_width // image_size

    processed_images = []
    for i in range(blocks):
        column = i % tiles_per_row
        row = i // tiles_per_row
        box = (
            column * image_size,
            row * image_size,
            (column + 1) * image_size,
            (row + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B
def image_to_pixel_values_skyworkr1v(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Split an image into tiles and stack the transformed tiles.

    Returns:
        One tensor holding every transformed tile (and the thumbnail, when
        one is produced), stacked along a new leading dimension.
    """
    candidate_grids = get_skyworkr1v_target_ratios(min_num, max_num)
    to_tensor = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_skyworkr1v(
        image,
        target_ratios=candidate_grids,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([to_tensor(tile) for tile in tiles])


273
class SkyworkR1VProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> None:
        """Store the config/tokenizer and resolve the tiling settings.

        Each keyword argument left as ``None`` falls back to the attribute
        of the same name on ``config``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Number of IMG_CONTEXT tokens that one tile expands to.
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the ``IMG_CONTEXT`` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """Build the text that replaces one ``<image>`` placeholder.

        The replacement is ``IMG_START``, then ``feature_size`` copies of
        ``IMG_CONTEXT``, then ``IMG_END``. ``num_patches`` is accepted but
        not used here.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> tuple[int, int]:
        """Resolve the tile-count bounds, defaulting to the instance settings."""
        min_dynamic_patch = (
            self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
        )
        max_dynamic_patch = (
            self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch
        )
        dynamic_image_size = (
            self.dynamic_image_size
            if dynamic_image_size is None
            else dynamic_image_size
        )
        use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail

        return resolve_skyworkr1v_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate tiling grids for the resolved bounds."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_skyworkr1v_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many image tokens an image of the given size produces."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_skyworkr1v_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its stacked tensor of tiles."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_skyworkr1v(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            )
            for image in images
        ]

    def __call__(
        self,
        text: str | list[str] | None = None,
        images: Image.Image | list[Image.Image] | None = None,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """Tokenize the text and preprocess the images.

        For each image, the first remaining ``<image>`` placeholder in every
        text is replaced by that image's expanded token string before
        tokenization.

        Returns:
            A ``BatchFeature`` with the tokenizer outputs and, when images
            are given, ``pixel_values_flat`` and ``image_num_patches``.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if not images:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat": torch.cat(pixel_values_lst),
                "image_num_patches": torch.tensor(
                    [len(item) for item in pixel_values_lst]
                ),
            }

            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl(feature_size, num_patches)

                # Only the first remaining placeholder is consumed per image.
                text = [t.replace("<image>", image_repl.full, 1) for t in text]

        text_inputs = self.tokenizer(text)

        combined_outputs = {**text_inputs, **image_inputs}

        return BatchFeature(combined_outputs, tensor_type=return_tensors)
471
472


473
474
475
476
477
478
479
480
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for SkyworkR1V: processor factory and token counts."""

    def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
        """Create a ``SkyworkR1VProcessor`` from the HF config and tokenizer."""
        return self.ctx.init_processor(
            SkyworkR1VProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the supported modalities; ``None`` is given for images."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: SkyworkR1VProcessor | None,
    ) -> int:
        """Return the image token count for an image of the given size.

        A processor is created on demand when ``processor`` is ``None``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that yields the most image tokens.

        Every candidate grid is turned into an image size (grid dimensions
        times the tile size) and the first one with the highest token count
        is returned.

        Raises:
            ValueError: If no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width, height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


525
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
    """Builds dummy prompts and images for SkyworkR1V."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        num_images = mm_counts.get("image", 0)

        return "<image>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image tokens.

        ``seq_len`` is accepted but not used here. Any ``"image"`` entry in
        ``mm_options`` is forwarded as overrides to the dummy image builder.
        """
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


552
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
    """Multimodal processor that wires ``SkyworkR1VProcessor`` into vLLM."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the base processor and attach the image token id."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_token_id = hf_processor.image_token_id

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto the image items."""
        # Falls back to an empty tensor when no images were processed.
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches
            ),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the update that expands each ``<image>`` placeholder."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        out_mm_data = out_mm_kwargs.get_data()
        if "image_num_patches" in out_mm_data:
            image_num_patches = out_mm_data["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_data["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_skyworkr1v(item_idx: int):
            """Build the replacement for the image at ``item_idx``."""
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_skyworkr1v,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    SkyworkR1VMultiModalProcessor,
    info=SkyworkR1VProcessingInfo,
    dummy_inputs=SkyworkR1VDummyInputsBuilder,
)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """SkyworkR1V chat model: vision encoder, ``mlp1`` projector and LLM."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for a modality.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        # NOTE: must be set before _init_mlp1 runs, which reads it.
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        llm_arch_name = config.text_config.architectures[0]
        self.is_mono = llm_arch_name == "SkyworkLM2VEForCausalLM"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                is_mono=self.is_mono,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(
                config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.img_context_token_id = None
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ) -> None:
        """Exclude the vision model from AWQ quantization when unspecified."""
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config", None)
            if (not quant_config.modules_to_not_convert) and (
                llm_quant_config is not None
            ):
                quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision encoder.

        For the non-mono architecture the encoder is built only up to the
        layer selected by ``config.select_layer`` (negative values count
        from the end). The mono architecture uses the patch model.
        """
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = (
                    config.vision_config.num_hidden_layers + vision_feature_layer + 1
                )
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            return InternVisionPatchModel(config.vision_config)

    def _init_mlp1(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the projector from vision features to the LLM hidden size."""
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # Channel width of the vision features after pixel shuffling.
        projector_in_size = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        # The ".1" / ".3" prefixes are the positions of the linear layers
        # inside the Sequential container.
        return nn.Sequential(
            nn.LayerNorm(projector_in_size),
            ReplicatedLinear(
                projector_in_size,
                llm_hidden_size,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.1",
            ),
            nn.GELU(),
            ReplicatedLinear(
                llm_hidden_size,
                llm_hidden_size,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.3",
            ),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        Both spatial dimensions are scaled by ``scale_factor`` and the
        channel dimension by ``1 / scale_factor ** 2``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        # Version "v1" keeps the layout as is; other versions swap the two
        # spatial dimensions back.
        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode image tiles and project them to the LLM hidden size."""
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the first token of every sequence
        # (presumably the class token -- TODO confirm).
        vit_embeds = vit_embeds[:, 1:, :]

        # The remaining tokens are arranged as a square grid.
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> SkyworkR1VImageInputs | None:
        """Turn the raw multimodal kwargs into a typed image input.

        Returns ``None`` when neither pixel values nor embeddings are given.
        Embeddings take precedence over pixel values. On the pixel path the
        image token id is recorded on the instance.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return SkyworkR1VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        image_token_id = kwargs["image_token_id"]
        if isinstance(image_token_id, torch.Tensor):
            # ``item()`` requires that all entries hold the same id.
            image_token_id = image_token_id.flatten().unique().item()

        assert isinstance(image_token_id, int)
        self.img_context_token_id = image_token_id

        if pixel_values_flat is not None:
            return SkyworkR1VImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=pixel_values_flat,
                num_patches=image_num_patches,
                resolve_bindings={
                    "h": self.config.vision_config.image_size,
                    "w": self.config.vision_config.image_size,
                },
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: SkyworkR1VImageInputs,
    ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
        """Return the embeddings for the given image input.

        Precomputed embeddings are passed through unchanged. Pixel inputs
        are encoded and then split into one tensor per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

        num_patches = image_input["num_patches"]

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(-1, self.config.text_config.hidden_size).unsqueeze(
                0
            )

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
        image_feature_sizes = [
            patches_in_image * feature_size for patches_in_image in num_patches
        ]
        return image_embeds.split(image_feature_sizes)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Record which positions hold image tokens (mono architecture only)."""
        if self.is_mono:
            self.visual_token_mask = (input_ids == self.img_context_token_id).reshape(
                -1, 1
            )
        else:
            self.visual_token_mask = None

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed the token ids, delegating the merge to the parent class.

        The visual token mask is refreshed first whenever multimodal
        embeddings are present.
        """
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model backbone.

        ``inputs_embeds`` is ignored when ``intermediate_tensors`` is given.
        A pending visual token mask is forwarded once and then cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs.update({"visual_token_mask": self.visual_token_mask})
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping the listed prefixes."""
        skip_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)