# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping
11
from typing import Annotated, Literal, TypeAlias
12
13
14

import torch
import torch.nn as nn
15
from transformers import PretrainedConfig
16
17

from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.inputs import MultiModalDataDict
20
21
22
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
23
24
25
26
from vllm.model_executor.models.intern_vit import (
    InternVisionModel,
    InternVisionPatchModel,
)
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.processing import BaseDummyInputsBuilder
29
from vllm.sequence import IntermediateTensors
30
31
32
33
from vllm.transformers_utils.processors.internvl import (
    InternVLImageProcessor,
    InternVLProcessor,
)
34
from vllm.utils.tensor_schema import TensorSchema, TensorShape
35
36

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
37
38
39
40
41
from .internvl import (
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
)
42
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
43
44


class SkyworkR1VImagePixelInputs(TensorSchema):
    """
    Image input given as flattened pixel tiles.

    Dimensions:
        - bnp: Batch size * number of images * (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bn: Batch size * number of images
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Every tile of every image, stacked along the first axis.
    pixel_values_flat: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "h", "w"),
    ]

    # Number of tiles belonging to each image; used to split the stacked
    # features back into per-image chunks.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]


class SkyworkR1VImageEmbeddingInputs(TensorSchema):
    """
    Image input given as precomputed embeddings, used without running the
    vision tower.

    Dimensions:
        - ni: Number of images
        - ifs: Image feature size
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("ni", "ifs", "hs"),
    ]


# A validated image input: either raw pixel tiles or ready-made embeddings.
SkyworkR1VImageInputs: TypeAlias = (
    SkyworkR1VImagePixelInputs | SkyworkR1VImageEmbeddingInputs
)


class SkyworkR1VProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for Skywork-R1V built on the InternVL processors.

    Image-processor options that the caller does not supply are taken from
    the Skywork HF config.
    """

    def get_image_processor(self, **kwargs: object) -> InternVLImageProcessor:
        """Build an ``InternVLImageProcessor``.

        ``kwargs`` are first passed through ``self.ctx.get_merged_mm_kwargs``;
        any of the keys below still missing afterwards are filled in from the
        HF config, so explicitly provided values always win.
        """
        config = self.get_hf_config()
        vision_config = config.vision_config

        kwargs = self.ctx.get_merged_mm_kwargs(kwargs)
        config_defaults = {
            "image_size": vision_config.image_size,
            "min_dynamic_patch": config.min_dynamic_patch,
            "max_dynamic_patch": config.max_dynamic_patch,
            "dynamic_image_size": config.dynamic_image_size,
            "use_thumbnail": config.use_thumbnail,
        }
        for key, value in config_defaults.items():
            kwargs.setdefault(key, value)

        return InternVLImageProcessor(**kwargs)

    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Build an ``InternVLProcessor`` around this model's tokenizer.

        ``image_seq_length`` is computed as
        ``(image_size // patch_size) ** 2 * downsample_ratio ** 2``, where
        ``image_size`` is taken from the image processor actually built (so it
        honours any override passed in ``kwargs``).
        """
        config = self.get_hf_config()
        image_processor = self.get_image_processor(**kwargs)

        grid_side = image_processor.image_size // config.vision_config.patch_size
        image_seq_length = int(grid_side**2 * (config.downsample_ratio**2))

        return InternVLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            image_seq_length=image_seq_length,
        )


class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
    """Builds dummy prompt text and dummy images for Skywork-R1V.

    NOTE(review): the processor registration on ``SkyworkR1VChatModel`` in
    this module passes ``BaseInternVLDummyInputsBuilder``, not this class;
    confirm whether this builder is still referenced elsewhere.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        return "<image>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images of the largest size.

        The size comes from ``info.get_image_size_with_most_features()``;
        ``mm_options["image"]`` (if present) is forwarded as overrides.
        ``seq_len`` is accepted for interface compatibility and not used.
        """
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)
        # Tolerate a missing options mapping instead of failing on ``.get``.
        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


@MULTIMODAL_REGISTRY.register_processor(
    BaseInternVLMultiModalProcessor,
    info=SkyworkR1VProcessingInfo,
    # NOTE(review): the generic InternVL dummy-inputs builder is registered
    # here, so SkyworkR1VDummyInputsBuilder (defined in this module) is not
    # used by this registration. Confirm which builder is intended.
    dummy_inputs=BaseInternVLDummyInputsBuilder,
)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """Skywork-R1V vision-language model.

    Images go through a vision tower (``vision_model``), are pixel-shuffled
    and projected by an MLP (``mlp1``), and are then merged into the input
    embeddings of a vLLM-registered language model (``language_model``).

    When the text backbone architecture is ``SkyworkLM2VEForCausalLM``
    ("mono"), the vision tower is an ``InternVisionPatchModel`` and a
    per-token visual mask is forwarded to the language model.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an item of ``modality``.

        Raises:
            ValueError: if ``modality`` is not an image modality.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # Visual tokens per tile: the ViT patch grid, shrunk by pixel shuffle
        # (downsample_ratio per spatial side).
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        llm_arch_name = config.text_config.architectures[0]
        self.is_mono = llm_arch_name == "SkyworkLM2VEForCausalLM"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                is_mono=self.is_mono,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(
                config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Set while parsing pixel inputs (img_context_token_id) and while
        # embedding input ids (visual_token_mask, mono models only).
        self.img_context_token_id = None
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ) -> None:
        """Add ``"vision_model"`` to an AWQ config's ``modules_to_not_convert``.

        AWQ checkpoints from OpenGVLab are missing ``modules_to_not_convert``.
        When ``quant_config`` is an ``AWQConfig`` with an empty
        ``modules_to_not_convert`` and the text config carries a
        ``quantization_config``, ``"vision_model"`` is appended (mutating
        ``quant_config`` in place). Any other config, including ``None``, is
        left untouched.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config", None)
        if not quant_config.modules_to_not_convert and llm_quant_config is not None:
            quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ) -> nn.Module:
        """Create the vision tower.

        Mono models use ``InternVisionPatchModel``. Otherwise an
        ``InternVisionModel`` is truncated so that its last layer is
        ``config.select_layer`` (negative values count from the end).
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = (
                config.vision_config.num_hidden_layers + vision_feature_layer + 1
            )
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )

    def _init_mlp1(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> nn.Module:
        """Create the projector: LayerNorm -> Linear -> GELU -> Linear.

        Maps pixel-shuffled ViT features to the language model hidden size.
        The linear layers' weight prefixes (``.1`` and ``.3``) match their
        positions inside the ``nn.Sequential``.
        """
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        # Pixel shuffle stacks int(1 / downsample_ratio) ** 2 neighbouring ViT
        # tokens into the channel dimension.
        shuffled_hidden_size = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        return nn.Sequential(
            nn.LayerNorm(shuffled_hidden_size),
            ReplicatedLinear(
                shuffled_hidden_size,
                llm_hidden_size,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.1",
            ),
            nn.GELU(),
            ReplicatedLinear(
                llm_hidden_size,
                llm_hidden_size,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.3",
            ),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        With scale ``s``, an ``(N, W, H, C)`` tensor becomes
        ``(N, H * s, W * s, C / s**2)``. For ``ps_version`` other than
        ``"v1"``, the two spatial axes are swapped back afterwards.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // scale**2
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel tiles into language-model-sized visual embeddings.

        Drops the first token of the vision output (presumably the class
        token), reshapes the remainder into a square grid, pixel-shuffles it
        by ``downsample_ratio`` and projects it with ``mlp1``.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        # Assumes the remaining tokens form a square grid.
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> SkyworkR1VImageInputs | None:
        """Turn raw multimodal kwargs into a typed image input.

        Returns ``None`` when neither ``pixel_values_flat`` nor
        ``image_embeds`` is given. Precomputed embeddings take priority.
        For pixel inputs, ``kwargs["image_token_id"]`` is required and is
        stored on ``self.img_context_token_id``.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if image_embeds is not None:
            return SkyworkR1VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        if pixel_values_flat is None:
            return None

        image_token_id = kwargs["image_token_id"]
        if isinstance(image_token_id, torch.Tensor):
            # May arrive batched; `.item()` requires every entry to agree.
            image_token_id = image_token_id.flatten().unique().item()

        assert isinstance(image_token_id, int)
        self.img_context_token_id = image_token_id

        return SkyworkR1VImagePixelInputs(
            type="pixel_values",
            pixel_values_flat=pixel_values_flat,
            num_patches=image_num_patches,
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            },
        )

    def _process_image_input(
        self,
        image_input: SkyworkR1VImageInputs,
    ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
        """Compute per-image embeddings.

        Precomputed embeddings are returned unchanged. For pixel inputs, a
        single image yields one tensor of shape ``(1, tokens, hidden)``;
        several images yield a tuple with one ``(tokens_i, hidden)`` chunk per
        image, sized ``num_patches[i] * tokens_per_tile``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
        num_patches = image_input["num_patches"]
        hidden_size = self.config.text_config.hidden_size

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(-1, hidden_size).unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        image_feature_sizes = [
            patches_per_image * feature_size for patches_per_image in num_patches
        ]
        return image_embeds.split(image_feature_sizes)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """For mono models, record which input positions are image tokens.

        The mask has shape ``(num_tokens, 1)``. Non-mono models clear it.
        NOTE(review): ``img_context_token_id`` is only assigned on the pixel
        path of ``_parse_and_validate_image_input``; if a mono model receives
        only ``image_embeds`` it may still be ``None`` here — confirm that
        combination cannot occur.
        """
        if self.is_mono:
            is_image_token = input_ids == self.img_context_token_id
            self.visual_token_mask = is_image_token.reshape(-1, 1)
        else:
            self.visual_token_mask = None

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given kwargs (empty if none)."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in multimodal embeddings if given.

        Also refreshes ``visual_token_mask`` whenever non-empty multimodal
        embeddings are present, for use by the next ``forward`` call.
        """
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone.

        ``inputs_embeds`` is ignored when ``intermediate_tensors`` is given.
        A pending ``visual_token_mask`` (mono models only) is passed to the
        backbone once and then cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Weights under the listed prefixes are skipped; none of them name a
        submodule created in ``__init__``.
        """
        skip_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)