glm4v.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

# Adapted from
5
# https://github.com/zai-org/CogAgent
6
"""Inference-only CogAgent model compatible with THUDM weights."""
7

8
from argparse import Namespace
9
from collections.abc import Iterator, Mapping, Sequence
10
from typing import Annotated, Literal
11

12
import numpy as np
13
14
15
16
17
import torch
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
18
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
19
20
21
22
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
23
from vllm.config.multimodal import BaseDummyOptions
24
25
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
26
from vllm.model_executor.layers.attention import MMEncoderAttention
27
from vllm.model_executor.layers.conv import Conv2dLayer
28
29
30
31
32
33
34
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
35
36
37
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
38
39
from vllm.multimodal.inputs import (
    MultiModalDataDict,
40
    MultiModalFeatureSpec,
41
42
43
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
44
from vllm.multimodal.parse import MultiModalDataItems
45
from vllm.multimodal.processing import (
46
    BaseDummyInputsBuilder,
47
48
49
50
51
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
52
53
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
54
from vllm.utils.tensor_schema import TensorSchema, TensorShape
55

56
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
57
58
59
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
60
    SupportsMRoPE,
61
62
63
    SupportsMultiModal,
    SupportsPP,
)
64
65


class GLMVImagePixelInputs(TensorSchema):
    """
    Pixel-value image input accepted by GLM-4V.

    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height of image
        - w: Width of image
    """

    type: Literal["pixel_values"] = "pixel_values"
    # Channel count is fixed to 3; h/w are bound at validation time.
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]


class EVA2CLIPPatchEmbedding(nn.Module):
    """Split images into patches, prepend a CLS token and add position embeddings."""

    def __init__(self, config):
        super().__init__()
        # kernel_size == stride == patch_size -> non-overlapping patches.
        self.proj = Conv2dLayer(
            config.in_channels,
            config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D), where L is the number of
            patches plus one (the prepended CLS token).
        """
        # Cast inputs to the conv weight's device/dtype before projecting.
        images = images.to(device=self.proj.weight.device, dtype=self.proj.weight.dtype)
        x = self.proj(images)  # (B, D, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, H' * W', D)
        cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)  # (B, 1, D)
        x = torch.cat((cls_token, x), dim=1)
        # The whole position table is added, so this assumes
        # num_positions == number of patches + 1.
        x += self.position_embedding.weight.unsqueeze(0)
        return x


class EVA2CLIPAttention(nn.Module):
    """Multi-head self-attention block of the EVA2-CLIP vision tower.

    Q, K and V come from one fused tensor-parallel projection; attention heads
    are divided across tensor-parallel ranks.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        # Integer division: assumes num_heads is divisible by the TP world size.
        self.num_heads_per_rank = config.num_heads // self.tp_size
        self.head_dim = config.hidden_size // config.num_heads
        self.scale = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            config.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.attn = MMEncoderAttention(
            self.num_heads_per_rank,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attend over ``x`` and project the result back to ``hidden_size``."""
        qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
        # The fused output is split into three equal parts: q, k, v.
        q, k, v = qkv.chunk(3, dim=-1)

        out = self.attn(q, k, v)
        output, _ = self.dense(out)
        output = self.output_dropout(output)
        return output


class EVA2CLIPMLP(nn.Module):
    """Two-layer feed-forward block (fc1 -> activation -> fc2) of the vision tower."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        # Activation is selected by name from the vision config.
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``hidden_size`` -> ``intermediate_size`` -> ``hidden_size``."""
        x, _ = self.fc1(x)
        x = self.activation_fn(x)
        x, _ = self.fc2(x)
        return x


class EVA2CLIPTransformerLayer(nn.Module):
    """One EVA2-CLIP encoder layer: self-attention followed by an MLP."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = EVA2CLIPAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.attention"
        )
        self.mlp = EVA2CLIPMLP(
            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )
        self.post_attention_layernorm = LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(self, hidden_states):
        """Apply ``x + norm(sublayer(x))`` for attention, then for the MLP."""
        attention_input = hidden_states
        # The LayerNorm is applied to the sublayer *output*, before the
        # residual add (not to the sublayer input).
        attention_output = self.input_layernorm(self.attention(attention_input))
        hidden_states = attention_input + attention_output
        mlp_input = hidden_states
        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
        output = mlp_input + mlp_output
        return output


class EVA2CLIPTransformer(nn.Module):
    """Stack of ``config.num_hidden_layers`` EVA2-CLIP encoder layers."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                EVA2CLIPTransformerLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(self, hidden_states):
        """Run ``hidden_states`` through every layer in order."""
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states)
        return hidden_states


class EVA2CLIPGLU(nn.Module):
    """Gated MLP that projects vision features to the language model's width.

    Pipeline: linear_proj -> LayerNorm -> GELU -> merged (gate, up) projection
    -> SiluAndMul -> dense_4h_to_h.
    """

    def __init__(
        self,
        config,
        in_features,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        The original implementation is the same as:
        ```python
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        We merge the two ColumnParallelLinear layers into one
        MergedColumnParallelLinear:
        ```python
        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
        )

        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        self.linear_proj = ReplicatedLinear(
            in_features,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_proj",
        )
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj",
        )

        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, x):
        """Project ``x`` (last dim ``in_features``) to ``config.hidden_size``."""
        x, _ = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        # Output is [gate | up] concatenated (see __init__ docstring);
        # SiluAndMul combines the two halves.
        x, _ = self.merged_proj(x)
        x = self.act2(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class EVA2CLIPModel(nn.Module):
    """EVA2-CLIP vision tower plus the adapter into the language model.

    images -> patch embedding -> transformer -> drop CLS -> 2x2 strided conv
    -> GLU projection -> wrap with learned boi/eoi embeddings -> scale.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # vision_config is a plain dict on the HF config; expose it as attributes.
        vision_config = Namespace(**config.vision_config)
        self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config)
        self.transformer = EVA2CLIPTransformer(
            vision_config, quant_config=quant_config, prefix=f"{prefix}.transformer"
        )
        self.linear_proj = EVA2CLIPGLU(
            config,
            in_features=config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_proj",
        )
        # Halves each spatial side of the patch grid and maps vision width to
        # the language model's hidden size.
        self.conv = Conv2dLayer(
            in_channels=vision_config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=2,
            stride=2,
        )
        self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D); L includes the boi and eoi
            embeddings around the image tokens.
        """
        x = self.patch_embedding(images)
        x = self.transformer(x)
        x = x[:, 1:]  # drop the CLS token

        b, s, h = x.shape
        # Assumes a square patch grid (s is a perfect square).
        grid_size = int(s**0.5)
        x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
        x = self.conv(x)

        x = x.flatten(2).transpose(1, 2)
        x = self.linear_proj(x)
        boi = self.boi.expand(x.shape[0], -1, -1)
        eoi = self.eoi.expand(x.shape[0], -1, -1)
        x = torch.cat((boi, x, eoi), dim=1)
        x = x / self.scaling_factor
        return x


class GLM4VModel(ChatGLMModel):
    """ChatGLM backbone with an EVA2-CLIP vision tower attached as ``self.vision``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        quant_config = vllm_config.quant_config

        self.vision = EVA2CLIPModel(
            self.config, quant_config, prefix=f"{prefix}.vision"
        )


class GLM4VProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = config.vision_config
        image_size = vision_config["image_size"]

        # Every image is resized to a fixed square, converted to a tensor and
        # normalized with the mean/std constants below.
        self.image_transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images`` into one BatchFeature.

        Single items are wrapped into lists; ``pixel_values`` is only present
        when at least one image is given.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        text_inputs = self.tokenizer(text)

        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values = [self.image_transform(image) for image in images]
            image_inputs = {"pixel_values": torch.stack(pixel_values)}

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )


class GLM4VProcessingInfo(BaseProcessingInfo):
    """Static processing information (config, processor, token counts) for GLM-4V."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(ChatGLMConfig)

    def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
        return self.ctx.init_processor(
            GLM4VProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one image per prompt.
        return {"image": 1}

    def get_num_image_tokens(self) -> int:
        """Number of image tokens, excluding the boi/eoi tokens."""
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config

        image_size = vision_config["image_size"]
        patch_size = vision_config["patch_size"]
        # Patch grid side, halved by the 2x2 stride-2 conv in the vision adapter.
        grid_length = image_size // patch_size // 2
        return grid_length * grid_length

    def get_num_image_feature_tokens(self) -> int:
        # EVA2CLIPModel has embeddings for boi and eoi tokens as well
        return self.get_num_image_tokens() + 2


class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
    """Builds dummy text/images (e.g. for profiling) in GLM-4V's prompt format."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"

        return base_text * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy square images of ``image_size``."""
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config

        target_width = target_height = vision_config["image_size"]
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
    """Multimodal processor that expands each image placeholder to its full span."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # GLM4VProcessor does not expand placeholders itself; vLLM applies
        # the prompt updates from _get_prompt_updates instead.
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace ``boi pad eoi`` with ``boi`` + N pad tokens + ``eoi``."""
        hf_config = self.info.get_hf_config()

        boi_token_id = hf_config.boi_token_id
        # The pad token doubles as the per-image-token placeholder.
        image_token_id = hf_config.pad_token_id
        eoi_token_id = hf_config.eoi_token_id

        def get_replacement(item_idx: int):
            num_image_tokens = self.info.get_num_image_tokens()
            image_tokens = [image_token_id] * num_image_tokens

            return [boi_token_id] + image_tokens + [eoi_token_id]

        return [
            PromptReplacement(
                modality="image",
                target=[boi_token_id, image_token_id, eoi_token_id],
                replacement=get_replacement,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    GLM4VMultiModalProcessor,
    info=GLM4VProcessingInfo,
    dummy_inputs=GLM4VDummyInputsBuilder,
)
class GLM4VForCausalLM(
    ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """GLM-4V: ChatGLM language model with an EVA2-CLIP vision tower."""

    # Fused module name -> checkpoint modules it is built from.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        "merged_proj": ["gate_proj", "dense_h_to_4h"],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image; other modalities raise."""
        if modality.startswith("image"):
            return "<|begin_of_image|><|endoftext|><|end_of_image|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[GLM4VModel] = GLM4VModel,
    ) -> None:
        with self._mark_composite_model(
            vllm_config,
            language_targets=GLMTransformer,
            tower_targets={"image": EVA2CLIPModel},
        ):
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                transformer_type=transformer_type,
            )

        self.transformer: GLM4VModel

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> GLMVImagePixelInputs | None:
        """Pop ``pixel_values`` from ``kwargs`` and validate it.

        Returns ``None`` when no images were passed. Height and width are
        bound to the fixed ``vision_config["image_size"]``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config["image_size"]
        return GLMVImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
        """Run validated pixel values through the vision tower."""
        pixel_values = image_input["data"].to(dtype=self.config.dtype)

        return self.transformer.vision(pixel_values)

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int]]:
        """Yield ``(offset, t, h, w)`` per image, ordered by prompt offset.

        ``h``/``w`` form the LLM-side token grid. This model's processor emits
        only ``pixel_values``, and every image is resized to
        ``vision_config["image_size"]``. So the grid is derived from the
        config dict rather than from per-item grid metadata: patch grid
        ``image_size // patch_size``, halved by EVA2CLIPModel's 2x2 stride-2
        conv. This matches GLM4VProcessingInfo.get_num_image_tokens.
        """
        vision_config = self.config.vision_config
        grid_length = vision_config["image_size"] // vision_config["patch_size"] // 2
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            if mm_feature.modality != "image":
                # glm4v only supports image modality
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")
            yield mm_feature.mm_position.offset, 1, grid_length, grid_length

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE position ids of shape (3, len(input_tokens)).

        Each image span is ``boi`` + t*h*w grid tokens + ``eoi`` (see
        GLM4VMultiModalProcessor). Text and the boi/eoi tokens get identical
        positions on all three axes. Grid tokens get (t, h, w) indices offset
        past ``boi``.

        Returns:
            The position tensor and ``max_position + 1 - len(input_tokens)``.
        """
        llm_pos_ids_list: list[np.ndarray] = []
        next_pos = 0  # first position id not yet assigned
        st = 0  # index of the first token not yet assigned a position
        for (
            offset,
            llm_grid_t,
            llm_grid_h,
            llm_grid_w,
        ) in self.iter_mm_grid_thw(mm_features):
            # Text before the image.
            text_len = offset - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + next_pos
            )
            # boi token (the replacement span starts at `offset`).
            boi_pos = next_pos + text_len
            llm_pos_ids_list.append(np.full((3, 1), boi_pos, dtype=np.int64))
            # Image grid tokens.
            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
                3, -1
            )
            llm_pos_ids_list.append(grid_indices + boi_pos + 1)
            # eoi token, placed right after the largest grid position.
            eoi_pos = boi_pos + 1 + int(grid_indices.max()) + 1
            llm_pos_ids_list.append(np.full((3, 1), eoi_pos, dtype=np.int64))
            next_pos = eoi_pos + 1
            # EVA2CLIPModel has embeddings for boi and eoi tokens as well
            st = offset + 1 + llm_grid_t * llm_grid_h * llm_grid_w + 1

        if st < len(input_tokens):
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + next_pos
            )

        if not llm_pos_ids_list:
            # Empty prompt: nothing to concatenate.
            return torch.zeros((3, 0), dtype=torch.int64), 0

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta

    embed_input_ids = SupportsMultiModal.embed_input_ids

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return vision embeddings for the images in ``kwargs`` ([] if none)."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # When intermediate tensors are supplied, inputs_embeds is ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states