step3_vl.py 23.4 KB
Newer Older
Song's avatar
Song committed
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
5
from math import sqrt
6
from typing import Annotated, Any, Literal, TypeAlias
Song's avatar
Song committed
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
11
from transformers import BatchFeature
Song's avatar
Song committed
12
13

from vllm.config import VllmConfig
14
from vllm.config.multimodal import BaseDummyOptions
Song's avatar
Song committed
15
16
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
17
from vllm.model_executor.layers.attention import MMEncoderAttention
18
from vllm.model_executor.layers.conv import Conv2dLayer
19
20
21
22
23
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Song's avatar
Song committed
24
25
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
Song's avatar
Song committed
31
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
32
from vllm.multimodal.processing import (
33
    BaseDummyInputsBuilder,
34
35
36
37
38
39
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
Song's avatar
Song committed
40
from vllm.sequence import IntermediateTensors
41
from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig
42
43
44
45
46
from vllm.transformers_utils.processors.step3_vl import (
    MAX_IMAGE_SIZE,
    Step3VLImageProcessor,
    Step3VLProcessor,
)
47
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Song's avatar
Song committed
48
49

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
50
51
52
53
54
55
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
56
from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
Song's avatar
Song committed
57
58


59
60
61
62
63
64
65
66
67
68
69
70
class Step3VLImagePixelInputs(TensorSchema):
    """
    Raw-pixel image inputs for Step3-VL.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
        - bnp: Batch size * number of images * number of patches
        - hp: Height of patch
        - wp: Width of patch
    """

    type: Literal["pixel_values"]

    # One full image per item.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]

    # Patches of all items stacked along the first axis.
    patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]

    # Number of rows of ``patch_pixel_values`` belonging to each item.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


class Step3VLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs supplied directly as embeddings for Step3-VL.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    # Per-item embedding sequences.
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]


# Every image input form the model accepts: raw pixels or embeddings.
Step3VLImageInputs: TypeAlias = (
    Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs
)


class Step3VLProcessingInfo(BaseProcessingInfo):
    """Processor factories and token-budget information for Step3-VL."""

    def get_image_processor(self, **kwargs):
        """Build a ``Step3VLImageProcessor``.

        ``enable_patch`` defaults to ``vision_config.enable_patch`` (``True``
        when the vision config has no such attribute) unless the caller
        passes it explicitly.
        """
        config = self.get_hf_config()

        kwargs.setdefault(
            "enable_patch",
            getattr(config.vision_config, "enable_patch", True),
        )

        return Step3VLImageProcessor(**kwargs)

    def get_hf_processor(self, **kwargs: object) -> Step3VLProcessor:
        """Build the combined tokenizer + image processor.

        Keyword arguments (the multimodal processor calls this with
        ``**hf_processor_mm_kwargs``) are forwarded to
        :meth:`get_image_processor`, so callers may pass processor options
        instead of getting a ``TypeError``.
        """
        return Step3VLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=self.get_image_processor(**kwargs),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No explicit per-prompt image limit is declared.
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Return the number of tokens produced for the largest image size."""
        image_processor = self.get_image_processor()
        target_width, target_height = self.get_image_size_with_most_features()

        return image_processor.get_num_image_tokens(target_width, target_height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image at the processor's maximum side length.
        return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)


class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
    """Produces placeholder prompt text and images used as dummy inputs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<im_patch>`` placeholder per requested image."""
        image_count = mm_counts.get("image", 0)
        return "".join("<im_patch>" for _ in range(image_count))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return the requested number of dummy images at the largest size."""
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]):
    """Expands image placeholder tokens and describes Step3-VL tensor fields."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with that image's feature ids.

        The replacement for item ``i`` is built from the item's
        ``num_patches`` and ``patch_newline_mask``. Only positions holding the
        image placeholder token id are marked to receive embeddings.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = processor.image_token_id

        def build_replacement(item_idx: int):
            item = out_mm_kwargs["image"][item_idx]
            patch_count = int(item["num_patches"].data)
            newline_mask = item["patch_newline_mask"].data
            feature_ids = processor.get_image_repl_feature_ids(
                1, patch_count, newline_mask.tolist()
            )
            return PromptUpdateDetails.select_token_id(
                seq=feature_ids,
                embed_token_id=placeholder_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[placeholder_id],
                replacement=build_replacement,
            )
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images.

        ``pixel_values`` and ``num_patches`` hold one entry per image;
        ``patch_pixel_values`` and ``patch_newline_mask`` are flat tensors
        partitioned by each image's ``num_patches``.
        """
        patch_sizes = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "patch_pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", patch_sizes
            ),
            "num_patches": MultiModalFieldConfig.batched("image"),
            "patch_newline_mask": MultiModalFieldConfig.flat_from_sizes(
                "image", patch_sizes
            ),
        }


def get_abs_pos(abs_pos, tgt_size):
    """Resize a square absolute position-embedding table to a new grid.

    Args:
        abs_pos: Position embeddings. After squeezing a leading size-1 axis,
            the first row is a class-token entry and the remaining rows form
            a ``src x src`` grid; the last axis is the embedding width.
        tgt_size: Number of target grid positions; its integer square root
            is the target side length.

    Returns:
        ``abs_pos`` itself when the source and target side lengths match.
        Otherwise a ``(1, tgt * tgt + 1, dim)`` tensor: the class-token row
        unchanged, followed by the grid bicubically resampled (computed in
        float32 and cast back to ``abs_pos.dtype``).
    """
    embed_dim = abs_pos.size(-1)
    table = abs_pos.squeeze(0)
    src_side = int(math.sqrt(table.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))

    if src_side == tgt_side:
        return abs_pos

    cls_row, grid_rows = table[:1], table[1:]
    orig_dtype = abs_pos.dtype

    # (src*src, dim) -> (1, dim, src, src) channel-first layout for interpolate.
    grid = grid_rows.view(1, src_side, src_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2).contiguous().to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        antialias=True,
        align_corners=False,
    ).to(orig_dtype)
    # Back to (tgt*tgt, dim) rows, then re-attach the class-token row.
    resized = resized.permute(0, 2, 3, 1).view(tgt_side * tgt_side, embed_dim)
    merged = torch.cat([cls_row, resized], dim=0)
    return merged.view(1, tgt_side * tgt_side + 1, embed_dim)


class Step3VisionEmbeddings(nn.Module):
    """Patch and class-token embeddings for the Step3 vision encoder.

    The output sequence holds ``pad_tp_size`` copies of the position-encoded
    class token, followed by one position-encoded embedding per image patch.
    """

    def __init__(self, config: Step3VisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))

        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        grid_side = self.image_size // self.patch_size
        self.num_patches = grid_side**2

        # Total number of class-token copies at the front of the output
        # (hard-coded padding).
        self.pad_tp_size = 4
        # The table keeps num_patches + 1 rows so pretrained weights load
        # unchanged; get_abs_pos resizes it when the patch grid differs.
        self.position_embedding = torch.nn.Embedding(
            self.num_patches + 1, self.embed_dim
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_patches + 1).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a class-token-padded token sequence."""
        batch = pixel_values.shape[0]
        # Conv output is [*, width, grid, grid]; flatten the grid into a
        # sequence of patch tokens.
        patches = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        cls = self.class_embedding.expand(batch, 1, -1)
        tokens = torch.cat([cls, patches], dim=1)
        pos_table = self.position_embedding(self.position_ids)
        tokens = tokens + get_abs_pos(pos_table, patches.size(1))

        # Prepend pad_tp_size - 1 more copies of the position-encoded class
        # token.
        cls_pad = tokens[:, :1, :].repeat(1, self.pad_tp_size - 1, 1)
        return torch.cat([cls_pad, tokens], dim=1)


class Step3VisionAttention(nn.Module):
    """Multi-head self-attention for the Step3 vision encoder.

    Heads are split across tensor-parallel ranks, except when the vision
    tower runs data-parallel, in which case each rank keeps every head.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.total_num_heads

        self.scale = self.head_dim**-0.5

        dp_vit = is_vit_use_data_parallel()
        tp_world = 1 if dp_vit else get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world == 0
        # Heads held by this rank.
        self.num_heads = self.total_num_heads // tp_world

        self.q_size = self.num_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=dp_vit,
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=dp_vit,
        )

        # Backend is chosen automatically by MMEncoderAttention.
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Self-attend over ``hidden_states`` (batch x time x channel)."""
        # Unpacking enforces a 3-D input.
        _batch, _seq_len, _ = hidden_states.size()

        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value)

        projected, _ = self.out_proj(context)
        return projected


class Step3VisionMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        dp_vit = is_vit_use_data_parallel()
        # Column-parallel up-projection paired with row-parallel
        # down-projection; TP is disabled when the ViT runs data-parallel.
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=dp_vit,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=dp_vit,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform to ``hidden_states``."""
        expanded, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(expanded))
        return projected


class Step3VisionEncoderLayer(nn.Module):
    """One Step3 vision encoder layer: self-attention then MLP."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Step3VisionAttention(
            config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Step3VisionMLP(
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:
        """Apply ``x + LayerNorm(sublayer(x))`` for attention, then the MLP.

        The norm is applied to each sublayer's output, not to its input.
        """
        attn_out = self.layer_norm1(self.self_attn(hidden_states))
        hidden_states = hidden_states + attn_out
        mlp_out = self.layer_norm2(self.mlp(hidden_states))
        return hidden_states + mlp_out


class Step3VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` Step3 vision encoder layers."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        stacked = [
            Step3VisionEncoderLayer(
                config,
                quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(stacked)

    def forward(
        self,
        inputs_embeds,
    ):
        """Run ``inputs_embeds`` through every layer in order."""
        out = inputs_embeds
        for layer in self.layers:
            out = layer(out)
        return out


class Step3VisionTransformer(nn.Module):
    """Step3 vision tower: patch embeddings followed by the encoder stack."""

    def __init__(
        self,
        config: Step3VisionEncoderConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.use_data_parallel = is_vit_use_data_parallel()
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        self.transformer = Step3VisionEncoder(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Embed ``pixel_values`` and encode them.

        In data-parallel ViT mode the encoder is invoked through
        ``run_dp_sharded_vision_model``; otherwise it is called directly.
        """
        embedded = self.embeddings(pixel_values)
        if not self.use_data_parallel:
            return self.transformer(inputs_embeds=embedded)
        return run_dp_sharded_vision_model(embedded, self.transformer)


@MULTIMODAL_REGISTRY.register_processor(
    Step3VLMultiModalProcessor,
    info=Step3VLProcessingInfo,
    dummy_inputs=Step3VLDummyInputsBuilder,
)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Step3-VL: a Step3 vision tower and projector feeding a language model."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (images only)."""
        if modality.startswith("image"):
            return "<im_patch>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # NOTE: This behavior is consistent with the previous OOV handling,
        # but does not currently handle the start/stop toks around the
        # image features (<patch_start> <patch_end> <im_start> <im_end>)
        # See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
        #
        # If this becomes an issue or we refactor to handle this using the
        # processor info in the future, it would probably be best to handle
        # those too.
        self.configure_mm_token_handling(
            self.config.text_config.vocab_size,
            [self.config.image_token_id],
        )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = Step3VisionTransformer(
                config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            # Projector: two strided convolutions over the patch grid, then
            # a linear map into the language-model hidden size.
            self.vit_downsampler = Conv2dLayer(
                config.vision_config.hidden_size,
                config.vision_config.output_hidden_size,
                kernel_size=2,
                stride=config.understand_projector_stride,
            )
            self.vit_downsampler2 = Conv2dLayer(
                config.vision_config.output_hidden_size,
                config.vision_config.output_hidden_size * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.vit_large_projector = nn.Linear(
                config.vision_config.output_hidden_size * 2,
                config.hidden_size,
                bias=config.projector_bias,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Step3VLImageInputs | None:
        """Pop image kwargs and wrap them in the matching input schema.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. Tensors are cast to the model dtype.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and patch_pixel_values is not None:
            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype),
                patch_pixel_values=patch_pixel_values.to(self.dtype),
                num_patches=num_patches,
            )

        if image_embeds is not None:
            # The embedding schema stores its tensor under ``data``.
            return Step3VLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds.to(self.dtype),
            )

        # Reached only when pixel_values is given without patch_pixel_values.
        raise AssertionError("This line should be unreachable.")

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ViT patch features into the language-model hidden size.

        ``image_features`` is ``(B, P, C)`` with ``P`` a perfect square. The
        patch axis is folded back into a ``sqrt(P) x sqrt(P)`` grid, passed
        through both downsampling convolutions, flattened back into a
        sequence, and mapped by ``vit_large_projector``.
        """
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and keep only the patch-token outputs."""
        # Step3VisionEmbeddings puts ``pad_tp_size`` class-token copies in
        # front of the patch tokens; strip exactly that many.
        num_prefix_tokens = self.vision_model.embeddings.pad_tp_size
        return self.vision_model(input_tensor)[:, num_prefix_tokens:]

    def _process_image_input(
        self, image_input: Step3VLImageInputs
    ) -> list[torch.Tensor]:
        """Return one ``(num_tokens, hidden)`` embedding tensor per image.

        Embedding inputs must already be in the language-model hidden size,
        so they are only split per image and never passed through the vision
        projector. For pixel inputs, each image's result is its patch
        features (if it has any patches) followed by its full-image features.
        """
        if image_input["type"] == "image_embeds":
            return list(image_input["data"].unbind(0))

        image_features = self._get_vision_model_output(image_input["pixel_values"])
        patch_pixel_values = image_input["patch_pixel_values"]
        patch_image_features = (
            self._get_vision_model_output(patch_pixel_values)
            if len(patch_pixel_values) > 0
            else None
        )
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = (
            self._process_image_features(patch_image_features)
            if patch_image_features is not None
            else None
        )

        merged_image_features: list[torch.Tensor] = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            # Plain int keeps the slice bounds and running offset off-tensor.
            num_patch = int(num_patch)
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx : cur_patch_idx + num_patch
                ]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(-1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0]
            )
        return merged_image_features

    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
        """Compute per-image embeddings, or ``[]`` when no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model.

        ``inputs_embeds`` is ignored when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping HF prefixes to vLLM ones."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)