step3_vl.py 23.7 KB
Newer Older
Song's avatar
Song committed
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
5
from math import sqrt
6
from typing import Annotated, Any, Literal, TypeAlias
Song's avatar
Song committed
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
11
from transformers import BatchFeature
Song's avatar
Song committed
12
13

from vllm.config import VllmConfig
14
from vllm.config.multimodal import BaseDummyOptions
Song's avatar
Song committed
15
16
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
17
from vllm.model_executor.layers.attention import MMEncoderAttention
18
from vllm.model_executor.layers.conv import Conv2dLayer
19
20
21
22
23
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Song's avatar
Song committed
24
25
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
Song's avatar
Song committed
31
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
32
from vllm.multimodal.processing import (
33
    BaseDummyInputsBuilder,
34
35
36
37
38
39
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
Song's avatar
Song committed
40
from vllm.sequence import IntermediateTensors
41
42
from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig
from vllm.transformers_utils.processors.step3_vl import Step3VLProcessor
43
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Song's avatar
Song committed
44
45

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
46
47
48
49
50
51
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
52
from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
Song's avatar
Song committed
53
54


55
56
57
58
59
60
61
62
63
64
65
66
class Step3VLImagePixelInputs(TensorSchema):
    """Schema for raw pixel inputs: one global view per image plus patches.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bnp: Batch size * number of images * number of patches
        - hp: Height of patch
        - wp: Width of patch
    """

    # Discriminator used by the model to pick the processing path.
    type: Literal["pixel_values"] = "pixel_values"
    # Full-image views.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
    # Patch crops for all images, concatenated along the first dimension.
    patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]
    # Number of patch crops belonging to each image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]

Song's avatar
Song committed
72

73
74
75
76
77
78
79
class Step3VLImageEmbeddingInputs(TensorSchema):
    """Schema for precomputed image embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator used by the model to pick the processing path.
    type: Literal["image_embeds"] = "image_embeds"
    # The embeddings themselves; note the field is named ``data``.
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
Song's avatar
Song committed
83
84


85
# Union of the two accepted image input schemas: raw pixels or embeddings.
Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs
Song's avatar
Song committed
86
87
88
89
90
91
92
93
94


class Step3VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Step3-VL: processor factory and token counts."""

    def get_hf_processor(self) -> Step3VLProcessor:
        """Build a `Step3VLProcessor` from the HF config and the tokenizer."""
        return Step3VLProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only the "image" modality is listed; its limit is `None`."""
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Return the token count for the image size with most features."""
        hf_processor = self.get_hf_processor()
        largest = self.get_image_size_with_most_features()
        return hf_processor.get_num_image_tokens(largest.width, largest.height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token upper bound for each modality.

        `seq_len` and `mm_counts` are accepted but not used here.
        """
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the hard-coded 3024 x 3024 reference image size."""
        return ImageSize(3024, 3024)

    def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
        """Sum the image token counts over every image in `mm_data`.

        Args:
            mm_data: Must contain exactly one key, "image", holding either a
                single image or a list/tuple of images exposing `width` and
                `height`.

        Raises:
            ValueError: If `mm_data` has any key other than "image" or does
                not contain "image" at all.
        """
        if len(mm_data) != 1 or "image" not in mm_data:
            raise ValueError("mm_data could only contain one key 'image' for step1o")

        image_data = mm_data["image"]
        if not isinstance(image_data, (list, tuple)):
            image_data = [image_data]

        # Build the processor once rather than once per image.
        hf_processor = self.get_hf_processor()
        return sum(
            hf_processor.get_num_image_tokens(img.width, img.height)
            for img in image_data
        )
Song's avatar
Song committed
127
128
129
130
131
132
133
134
135
136
137


class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
    """Builds dummy text and image inputs for Step3-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one "<im_patch>" placeholder per requested image."""
        num_images = mm_counts.get("image", 0)
        return "<im_patch>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized like the image with the most features.

        Args:
            seq_len: Accepted but not used here.
            mm_counts: Requested item count per modality; only "image" is read.
            mm_options: Per-modality overrides; the "image" entry (if any) is
                forwarded to `_get_dummy_images`.
        """
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


155
class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]):
    """Multimodal processor: placeholder expansion and field layout."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with its full token sequence.

        The replacement for image `item_idx` is obtained from the HF
        processor's `_get_image_repl_features`, using that image's
        `num_patches` and `patch_newline_mask` from `out_mm_kwargs`.
        """
        # NOTE(review): Step3VLProcessingInfo.get_hf_processor in this file
        # takes no keyword arguments, so non-empty kwargs would raise a
        # TypeError here - confirm against the base-class contract.
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_placeholder_token_id = hf_processor.image_token_id

        def get_replacement_step1o(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            num_patches = int(out_item["num_patches"].data)
            if num_patches > 0:
                patch_newline_mask = out_item["patch_newline_mask"].data
                image_repl_ids = hf_processor._get_image_repl_features(
                    1, num_patches, patch_newline_mask.tolist()
                )[1]
            else:
                # No patch crops for this image: there is no newline mask.
                image_repl_ids = hf_processor._get_image_repl_features(1, 0, None)[1]
            # Only positions holding the placeholder id receive embeddings.
            return PromptUpdateDetails.select_token_id(
                seq=image_repl_ids,
                embed_token_id=image_placeholder_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[image_placeholder_token_id],
                replacement=get_replacement_step1o,
            )
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto individual images.

        `pixel_values` and `num_patches` are batched per image, while
        `patch_pixel_values` and `patch_newline_mask` are flat tensors split
        by `num_patches`.
        """
        # Fall back to an empty tensor when the processor emitted no patches.
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            patch_pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
            num_patches=MultiModalFieldConfig.batched("image"),
            patch_newline_mask=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
        )


def get_abs_pos(abs_pos: torch.Tensor, tgt_size: int) -> torch.Tensor:
    """Resize a square grid of absolute position embeddings.

    Args:
        abs_pos: Position embeddings of shape `(1, 1 + S*S, dim)`, where the
            first row is the class-token embedding and the remaining rows
            form an `S x S` grid.
        tgt_size: Target number of grid positions (excluding the class
            token); its integer square root is the target grid side.

    Returns:
        `abs_pos` itself when the source and target grid sides match;
        otherwise a `(1, 1 + T*T, dim)` tensor in the original dtype whose
        grid part was resampled with antialiased bicubic interpolation. The
        class-token row is carried over unchanged.
    """
    dim = abs_pos.size(-1)
    abs_pos_new = abs_pos.squeeze(0)
    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]

    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size == tgt_size:
        return abs_pos

    # (S*S, dim) -> (1, dim, S, S), the layout F.interpolate expects.
    old_pos_embed = (
        old_pos_embed.view(1, src_size, src_size, dim)
        .permute(0, 3, 1, 2)
        .contiguous()
    )
    # Interpolate in float32, then cast back to the original dtype.
    old_pos_embed = old_pos_embed.to(torch.float32)
    new_pos_embed = F.interpolate(
        old_pos_embed,
        size=(tgt_size, tgt_size),
        mode="bicubic",
        antialias=True,
        align_corners=False,
    ).to(dtype)
    # (1, dim, T, T) -> (T*T, dim)
    new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
    new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
    vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
    vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
    return vision_pos_embed


class Step3VisionEmbeddings(nn.Module):
    """Patch + class-token embeddings with resizable position embeddings."""

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        # Non-overlapping patches: kernel size equals stride.
        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        self.position_embedding = torch.nn.Embedding(
            self.num_patches + 1, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_patches + 1).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed images into a token sequence.

        The output sequence is `pad_tp_size - 1` copies of the
        position-encoded class token, followed by the class token itself and
        then the patch tokens.
        """
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(
            pixel_values
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # Position embeddings are resized to the actual number of patches.
        embeddings = embeddings + get_abs_pos(
            self.position_embedding(self.position_ids), patch_embeds.size(1)
        )
        # pad: prepend extra copies of the class token.
        embeddings = torch.cat(
            [
                embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1),
                embeddings,
            ],
            dim=1,
        )
        return embeddings


class Step3VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the fused QKV projection, attention op and output projection.

        Args:
            config: Vision encoder config; `hidden_size` and
                `num_attention_heads` are read.
            quant_config: Optional quantization config passed to the linear
                layers.
            prefix: Module name prefix used for the sub-layers.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads

        self.scale = self.head_dim**-0.5

        # In data-parallel mode tensor parallelism is disabled for this
        # module, so every rank keeps all heads.
        use_data_parallel = is_vit_use_data_parallel()
        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.q_size = self.num_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        # Use unified MMEncoderAttention with automatic backend selection
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        # Fused projection, then split into equal query/key/value chunks.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        # Use unified MMEncoderAttention with automatic backend selection
        attn_output = self.attn(q, k, v)

        attn_output, _ = self.out_proj(attn_output)

        return attn_output


class Step3VisionMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two projections.

        Args:
            config: Vision encoder config; `hidden_act`, `hidden_size` and
                `intermediate_size` are read.
            quant_config: Optional quantization config passed to both layers.
            prefix: Module name prefix used for the sub-layers.
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        use_data_parallel = is_vit_use_data_parallel()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in sequence."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Step3VisionEncoderLayer(nn.Module):
    """One encoder layer: attention and MLP, each with a residual add."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the attention, MLP and the two layer norms.

        Args:
            config: Vision encoder config; `hidden_size` and `layer_norm_eps`
                are read here.
            quant_config: Optional quantization config for the sub-layers.
            prefix: Module name prefix used for the sub-layers.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Step3VisionAttention(
            config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Step3VisionMLP(
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:
        """Run attention then MLP.

        Each layer norm is applied to the sub-layer *output* before it is
        added back to the residual stream.
        """
        hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states))
        hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states))
        return hidden_states


class Step3VisionEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` encoder layers."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the layer stack.

        Args:
            config: Vision encoder config; `num_hidden_layers` sets the depth.
            quant_config: Optional quantization config for every layer.
            prefix: Module name prefix; layer `i` is named
                `{prefix}.layers.{i}`.
        """
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Step3VisionEncoderLayer(
                    config,
                    quant_config,
                    prefix=f"{prefix}.layers.{i}",
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
    ):
        """Pass `inputs_embeds` through every layer in order."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
        return hidden_states


class Step3VisionTransformer(nn.Module):
    """Vision tower: embeddings followed by the encoder stack."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the embeddings and the encoder.

        Args:
            config: Vision encoder config.
            quant_config: Optional quantization config for the encoder.
            prefix: Module name prefix; the encoder is named
                `{prefix}.transformer`.
        """
        super().__init__()
        self.config = config
        self.use_data_parallel = is_vit_use_data_parallel()
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        self.transformer = Step3VisionEncoder(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Embed `pixel_values` and run the encoder.

        In data-parallel mode the encoder call is delegated to
        `run_dp_sharded_vision_model`; otherwise it is invoked directly.
        """
        hidden_states = self.embeddings(pixel_values)
        if self.use_data_parallel:
            hidden_states = run_dp_sharded_vision_model(hidden_states, self.transformer)
        else:
            hidden_states = self.transformer(inputs_embeds=hidden_states)
        return hidden_states


487
488
489
490
491
492
493
494
495
496
497
498
@MULTIMODAL_REGISTRY.register_processor(
    Step3VLMultiModalProcessor,
    info=Step3VLProcessingInfo,
    dummy_inputs=Step3VLDummyInputsBuilder,
)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Step3-VL: a vision tower and projector feeding a language model."""

    # Checkpoint prefixes are remapped onto the nested language model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for `modality`.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if modality.startswith("image"):
            return "<im_patch>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector stack and the language model."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # NOTE: This behavior is consistent with the previous OOV handling,
        # but does not currently handle the start/stop toks around the
        # image features (<patch_start> <patch_end> <im_start> <im_end>)
        # See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
        #
        # If this becomes an issue or we refactor to handle this using the
        # processor info in the future, it would probably be best to handle
        # those too.
        self.configure_mm_token_handling(
            self.config.text_config.vocab_size,
            [self.config.image_token_id],
        )

        with self._mark_tower_model(vllm_config, "image"):
            # The vision tower is created without a quantization config.
            self.vision_model = Step3VisionTransformer(
                config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.vit_downsampler = Conv2dLayer(
                config.vision_config.hidden_size,
                config.vision_config.output_hidden_size,
                kernel_size=2,
                stride=config.understand_projector_stride,
            )
            self.vit_downsampler2 = Conv2dLayer(
                config.vision_config.output_hidden_size,
                config.vision_config.output_hidden_size * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.vit_large_projector = nn.Linear(
                config.vision_config.output_hidden_size * 2,
                config.hidden_size,
                bias=config.projector_bias,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Step3VLImageInputs | None:
        """Pop the image-related kwargs and wrap them in an input schema.

        Returns:
            `None` when neither `pixel_values` nor `image_embeds` is given;
            a `Step3VLImagePixelInputs` when both `pixel_values` and
            `patch_pixel_values` are given; otherwise a
            `Step3VLImageEmbeddingInputs` when `image_embeds` is given.

        Raises:
            AssertionError: If `pixel_values` is given without
                `patch_pixel_values` and no `image_embeds` is available.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and patch_pixel_values is not None:
            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype),
                patch_pixel_values=patch_pixel_values.to(self.dtype),
                num_patches=num_patches,
            )

        if image_embeds is not None:
            # The schema field is named `data`, not `image_embeds`.
            return Step3VLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds.to(self.dtype),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Downsample vision features and project them.

        The `P` tokens of each item are laid out as an `HW x HW` grid with
        `HW = int(sqrt(P))`, passed through the two conv downsamplers,
        flattened back to a token sequence and fed to the linear projector.
        """
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        # (B, P, C) -> (B, C, HW, HW)
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        # (B, C', H', W') -> (B, H'*W', C')
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and keep only the patch tokens.

        The embeddings module emits `pad_tp_size` leading non-patch tokens
        (the padded class-token copies plus the class token itself); they are
        stripped here.
        """
        num_prefix_tokens = self.vision_model.embeddings.pad_tp_size
        return self.vision_model(input_tensor)[:, num_prefix_tokens:]

    def _process_image_input(
        self, image_input: Step3VLImageInputs
    ) -> list[torch.Tensor]:
        """Turn a parsed image input into one embedding tensor per image.

        For pixel inputs, each image's output is its patch-crop features (if
        any) followed by its global-view features, flattened to 2-D.
        """
        if image_input["type"] == "image_embeds":
            # NOTE(review): precomputed embeddings are documented as already
            # matching the language-model hidden size, so they bypass the
            # vision projector and are returned one tensor per image.
            return list(image_input["data"])

        image_features = self._get_vision_model_output(image_input["pixel_values"])
        patch_pixel_values = image_input["patch_pixel_values"]
        patch_image_features = (
            self._get_vision_model_output(patch_pixel_values)
            if len(patch_pixel_values) > 0
            else None
        )
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        if patch_image_features is not None:
            patch_image_features = self._process_image_features(patch_image_features)

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            # Plain ints keep the slice bounds and running offset simple.
            num_patch = int(num_patch)
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return merged_image_features

    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
        """Return image embeddings for the given kwargs, or `[]` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate to the parent implementation with or without images."""
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model.

        When `intermediate_tensors` is provided, `inputs_embeds` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights via `AutoWeightsLoader` using `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)