tarsier.py 21.2 KB
Newer Older
汪志鹏's avatar
汪志鹏 committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
汪志鹏's avatar
汪志鹏 committed
3
4
5

import math
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
汪志鹏's avatar
汪志鹏 committed
7
8
9

import torch
import torch.nn as nn
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
汪志鹏's avatar
汪志鹏 committed
16
17
18
from transformers import LlavaConfig as HfLlavaConfig
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.models.llava import LlavaProcessor
19
from transformers.processing_utils import ProcessingKwargs, Unpack
汪志鹏's avatar
汪志鹏 committed
20
21
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

22
from vllm.config import VllmConfig
汪志鹏's avatar
汪志鹏 committed
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
汪志鹏's avatar
汪志鹏 committed
25
26
27
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
29
30
31
32
33
34
35
36
37
38
39
40
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
汪志鹏's avatar
汪志鹏 committed
41
from vllm.sequence import IntermediateTensors
42
from vllm.utils.tensor_schema import TensorSchema, TensorShape
汪志鹏's avatar
汪志鹏 committed
43
44
45
46

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
47
48
49
50
51
52
from .utils import (
    AutoWeightsLoader,
    get_layer_index,
    init_vllm_registered_model,
    maybe_prefix,
)
53
54
55
56
57
from .vision import (
    VisionEncoderInfo,
    get_num_selected_vision_tokens,
    get_vision_encoder_info,
)
汪志鹏's avatar
汪志鹏 committed
58
59


60
61
62
63
64
65
66
67
class TarsierImagePixelInputs(TensorSchema):
    """
    Raw image pixels that still have to go through the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"] = "pixel_values"
    # Stacked images; the schema pins the channel axis to 3.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


class TarsierImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image features that skip the vision tower and the projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    # Already in the language-model hidden space; Tarsier's split tokens are
    # still appended to these features by the model.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


# Image input for Tarsier: either raw pixels or precomputed embeddings.
TarsierImageInputs: TypeAlias = (
    TarsierImagePixelInputs | TarsierImageEmbeddingInputs
)


class TarsierHfConfig(Protocol):  # Based on the Tarsier's LlavaConfig
    """Structural type listing the HF config attributes Tarsier reads."""

    # Sub-configs for the vision encoder and the language-model backbone.
    vision_config: Final[PretrainedConfig]
    text_config: Final[PretrainedConfig]  # Added from Tarsier's LlavaConfig
    # Id of the <image> placeholder token in prompts.
    image_token_index: Final[int]
    # Forwarded to the vision tower as its feature-select strategy.
    vision_feature_select_strategy: Final[str]
    # Vision layer index (or indices) whose features are used; also bounds
    # how many encoder layers get instantiated.
    vision_feature_layer: Final[int | list[int]]
    # Activation name used between the projector's two linear layers.
    projector_hidden_act: Final[str]
    # Token ids whose embeddings are inserted after every patch row
    # (image_newline_idx) and after each whole image (image_new_idx).
    image_newline_idx: Final[int]
    image_new_idx: Final[int]
    multimodal_projector_bias: bool = True


class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    # Per-modality defaults consumed when TarsierProcessor merges its kwargs
    # via ``_merge_kwargs`` (HF ProcessingKwargs convention — assumed).
    # Text is tokenized without padding; images use the processor defaults.
    _defaults = dict(
        text_kwargs=dict(padding=False),
        images_kwargs={},
    )


class TarsierProcessor(LlavaProcessor):
    """LLaVA processor that expands ``<image>`` to Tarsier's token layout.

    Every image token in a prompt is repeated once per feature the model
    emits for the image: each patch row contributes its patches plus one
    row-separator, followed by the encoder's extra tokens (if kept by the
    feature-select strategy) and one trailing image-separator.
    """

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[TarsierProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess ``images`` and tokenize ``text``.

        Args:
            images: Images forwarded to the underlying image processor.
            text: A prompt string or a list of prompts. May be omitted when
                ``images`` is given, in which case only image features are
                returned.
            audio: Unused; accepted for signature compatibility.
            videos: Unused; accepted for signature compatibility.
            **kwargs: Processing kwargs merged with ``TarsierProcessorKwargs``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs (when ``text`` is
            given) together with the image processor outputs (when ``images``
            is given).

        Raises:
            ValueError: If neither ``images`` nor ``text`` is given, or if
                ``text`` has an unsupported type.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images, **output_kwargs["images_kwargs"]
            )
        else:
            image_inputs = {}

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)

        if text is None:
            # Images-only call (allowed by the guard above). Without this
            # early return the type check below fails on ``text[0]``.
            return BatchFeature(data={**image_inputs}, tensor_type=return_tensors)

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )

        # try to expand inputs in processing if we have the necessary parts
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            # Replace the image token with the expanded image token sequence.
            # Assumes every image was resized to the size of the first one.
            pixel_values = image_inputs["pixel_values"]
            height, width = get_image_size(to_numpy_array(pixel_values[0]))
            # rows * (patches per row + 1 row-separator), plus the encoder's
            # additional tokens (presumably CLS) and one image-separator.
            num_image_tokens = (
                (height // self.patch_size) * (width // self.patch_size + 1)
                + self.num_additional_image_tokens
                + 1
            )
            # The "default" strategy drops one of the additional tokens.
            if self.vision_feature_select_strategy == "default":
                num_image_tokens -= 1

            prompt_strings = [
                sample.replace(self.image_token, self.image_token * num_image_tokens)
                for sample in text
            ]

        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchFeature(
            data={**text_inputs, **image_inputs}, tensor_type=return_tensors
        )


class TarsierMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    ``linear_1`` is a ``ColumnParallelLinear`` (vision -> text size) and
    ``linear_2`` a ``RowParallelLinear`` (text -> text size), with the
    configured activation in between.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Options shared by both linear layers.
        common = dict(bias=multimodal_projector_bias, quant_config=quant_config)

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **common,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **common,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the language-model hidden size."""
        # Each parallel linear returns (output, bias); only output is used.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


class TarsierProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count lookups for Tarsier."""

    def get_hf_config(self) -> TarsierHfConfig:
        """Return the model's HF config, parsed as a LlavaConfig."""
        return self.ctx.get_hf_config(HfLlavaConfig)

    def get_vision_encoder_info(self) -> VisionEncoderInfo:
        """Return encoder metadata derived from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
        """Build a ``TarsierProcessor``, defaulting ``patch_size`` to the encoder's."""
        encoder_patch_size = self.get_vision_encoder_info().get_patch_size()
        kwargs.setdefault("patch_size", encoder_patch_size)
        return self.ctx.get_hf_processor(TarsierProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Images are the only modality; no per-prompt limit is imposed."""
        return dict(image=None)

    def _selected_patch_count(
        self,
        encoder_info: VisionEncoderInfo,
        select_strategy: str,
        width: int,
        height: int,
    ) -> int:
        """Encoder token count for a (width, height) image after selection."""
        raw_count = encoder_info.get_num_image_tokens(
            image_width=width,
            image_height=height,
        )
        return get_num_selected_vision_tokens(raw_count, select_strategy)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of LM tokens one image occupies after split-token insertion.

        Falls back to the largest supported image size when the given size
        yields no patches.

        Raises:
            ValueError: If neither the given nor the fallback size yields a
                positive number of patches.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        patches = self._selected_patch_count(
            encoder_info,
            hf_config.vision_feature_select_strategy,
            image_width,
            image_height,
        )
        if patches <= 0:
            fallback = self.get_image_size_with_most_features()
            patches = self._selected_patch_count(
                encoder_info,
                hf_config.vision_feature_select_strategy,
                fallback.width,
                fallback.height,
            )
            if patches <= 0:
                raise ValueError("Could not determine a valid number of image patches.")

        # Square patch grid assumed: one row-separator per row, plus a
        # single trailing image-separator token.
        rows = int(math.sqrt(patches))
        return patches + rows + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image whose side is the vision encoder's native image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Token count for the largest supported image."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width, image_height=height)

    def get_image_newline_idx(self) -> int:
        """Token id of the per-row separator."""
        return self.get_hf_config().image_newline_idx

    def get_image_new_idx(self) -> int:
        """Token id of the per-image separator."""
        return self.get_hf_config().image_new_idx


# Processing-info type parameter for the Tarsier dummy-input builder and
# multimodal processor.
_I_Tarsier = TypeVar("_I_Tarsier", bound=TarsierProcessingInfo)


class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]):
    """Reuses the LLaVA dummy-input builder unchanged."""


class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]):
    """Multimodal processor that expands each ``<image>`` prompt token."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel inputs and precomputed embeddings are both batched per image.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # Id of the <IMAGE> placeholder token.
        placeholder_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if not isinstance(items, ImageEmbeddingItems):
                size = items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                patch_count = items.get_feature_size(item_idx)
                # Assumes a square patch grid: one separator per row plus
                # one separator for the whole image.
                token_count = patch_count + int(math.sqrt(patch_count)) + 1

            return [placeholder_id] * token_count

        # Each single <IMAGE> token becomes the full per-image sequence.
        replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=get_replacement,
        )
        return [replacement]


def init_vision_tower_for_tarsier(
    hf_config: TarsierHfConfig,  # Use the Tarsier specific config protocol
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel:
    """Build the CLIP or SigLIP vision tower described by ``hf_config``.

    The tower is constructed with ``num_hidden_layers_override`` set to the
    value ``get_layer_index`` resolves for ``hf_config.vision_feature_layer``.
    For a list/tuple of layers, the largest resolved value is used.

    Args:
        hf_config: Tarsier HF config; the type of ``vision_config`` selects
            the tower class.
        quant_config: Optional quantization config forwarded to the tower.
        require_post_norm: Forwarded to the vision model constructor.
        prefix: Module name prefix forwarded to the vision model.

    Returns:
        A ``CLIPVisionModel`` or a ``SiglipVisionModel``.

    Raises:
        TypeError: If ``vision_feature_layer`` is not an int, list or tuple.
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        NotImplementedError: If ``vision_config`` is neither a CLIP nor a
            SigLIP vision config.
    """
    vision_config = hf_config.vision_config

    feature_layers = hf_config.vision_feature_layer
    base_num_hidden_layers = vision_config.num_hidden_layers

    if isinstance(feature_layers, int):
        num_hidden_layers_to_init = get_layer_index(
            feature_layers, base_num_hidden_layers
        )
    elif isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() over an empty sequence would fail with an opaque message.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index"
            )
        num_hidden_layers_to_init = max(
            get_layer_index(idx, base_num_hidden_layers) for idx in feature_layers
        )
    else:
        raise TypeError(
            f"vision_feature_layer type: {type(feature_layers)} is not supported"
        )

    if isinstance(vision_config, CLIPVisionConfig):
        vision_model_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        vision_model_cls = SiglipVisionModel
    else:
        msg = f"Unsupported vision config for Tarsier: {type(vision_config)}"
        raise NotImplementedError(msg)

    # Both tower classes take the same constructor arguments here.
    return vision_model_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers_to_init,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


@MULTIMODAL_REGISTRY.register_processor(
    TarsierMultiModalProcessor,
    info=TarsierProcessingInfo,
    dummy_inputs=TarsierDummyInputsBuilder,
)
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Tarsier vision-language model: vision tower, MLP projector and LLM.

    Projected image features are interleaved with the embeddings of two
    configured separator tokens (see ``_add_tarsier_split_tokens``).
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt placeholder for the ``i``-th item of ``modality``."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config: TarsierHfConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config  # Storing the Tarsier-specific HF config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_tarsier(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = TarsierMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=getattr(
                    config, "multimodal_projector_bias", True
                ),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )
            # Separator token ids kept as non-persistent buffers so they
            # follow the module's device without entering the state dict.
            for buffer_name, token_id in (
                ("image_newline_idx_tensor", config.image_newline_idx),
                ("image_new_idx_tensor", config.image_new_idx),
            ):
                self.register_buffer(
                    buffer_name,
                    torch.tensor([token_id], dtype=torch.long),
                    persistent=False,
                )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                # Use text_config from Tarsier's main config
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> TarsierImageInputs | None:
        """Wrap ``pixel_values`` or ``image_embeds`` (pixels win) in a schema."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            return TarsierImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
            )
        if image_embeds is not None:
            return TarsierImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )
        # No image input supplied.
        return None

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run the vision tower with the configured feature-select strategy."""
        # From vLLM LLaVA, vision tower output handling
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _add_tarsier_split_tokens(
        self, projected_image_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Implements Tarsier's `add_split_tokens` logic.

        The (num_images, num_patches, dim) features are viewed as a grid of
        ``int(sqrt(num_patches))`` rows. The ``image_newline`` embedding is
        appended after each row, then one ``image_new`` embedding is appended
        after the whole image. Output shape is
        (num_images, num_patches + rows + 1, dim).
        """
        num_images, num_projected_patches, embed_dim = projected_image_features.shape
        num_height_patches = int(math.sqrt(num_projected_patches))
        num_width_patches = num_projected_patches // num_height_patches
        target_device = projected_image_features.device

        embed_tokens = self.language_model.model.embed_tokens
        newline_vec = embed_tokens(
            self.image_newline_idx_tensor.to(target_device)
        ).squeeze(0)
        image_end_vec = embed_tokens(
            self.image_new_idx_tensor.to(target_device)
        ).squeeze(0)

        try:
            grid = projected_image_features.view(
                num_images, num_height_patches, num_width_patches, embed_dim
            )
        except RuntimeError as e:
            raise RuntimeError(
                "Cannot reshape projected_image_features"
                f" with shape {projected_image_features.shape} "
                f"to ({num_images}, {num_height_patches},"
                f" {num_width_patches}, {embed_dim}). "
                "Ensure num_projected_patches is compatible"
                " with a grid structure. "
                f"num_projected_patches={num_projected_patches}, "
                f"derived num_height_patches={num_height_patches}. "
            ) from e

        # Append one newline embedding at the end of every grid row.
        newline_column = newline_vec.expand(
            (num_images, num_height_patches, 1, embed_dim)
        )
        rows_flat = torch.cat([grid, newline_column], dim=2).view(
            num_images, num_projected_patches + num_height_patches, embed_dim
        )

        # Close each image with a single image-separator embedding.
        image_end = image_end_vec.expand((num_images, 1, embed_dim))
        return torch.cat([rows_flat, image_end], dim=1)

    def _process_image_pixels(
        self,
        inputs: TarsierImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Pixels -> vision features -> projected features -> split tokens."""
        selected = self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )  # type: ignore
        if not isinstance(selected, torch.Tensor):
            raise TypeError(
                f"_image_pixels_to_features type:"
                f" {type(selected)} is not supported"
            )
        return self._add_tarsier_split_tokens(self.multi_modal_projector(selected))

    def _process_image_input(
        self,
        image_input: TarsierImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn either image-input variant into final LM-space features."""
        if image_input["type"] != "image_embeds":
            return self._process_image_pixels(image_input)

        embeds = image_input["data"]
        if not isinstance(embeds, torch.Tensor):
            raise ValueError(
                "Incorrect type of image_embeds. "
                f"Got type: {type(embeds)}. "
            )
        # Precomputed embeddings skip the tower and projector.
        return self._add_tarsier_split_tokens(embeds)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Image embeddings for the batch, or ``[]`` when there are no images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone on ids or precomputed embeddings."""
        # When intermediate tensors are supplied, inputs_embeds is discarded.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)