tarsier.py 22.1 KB
Newer Older
汪志鹏's avatar
汪志鹏 committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
汪志鹏's avatar
汪志鹏 committed
3
4
5

import math
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
汪志鹏's avatar
汪志鹏 committed
7
8
9

import torch
import torch.nn as nn
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
汪志鹏's avatar
汪志鹏 committed
16
17
18
from transformers import LlavaConfig as HfLlavaConfig
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.models.llava import LlavaProcessor
19
from transformers.processing_utils import ProcessingKwargs, Unpack
汪志鹏's avatar
汪志鹏 committed
20
21
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

22
from vllm.config import MultiModalConfig, VllmConfig
汪志鹏's avatar
汪志鹏 committed
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
汪志鹏's avatar
汪志鹏 committed
25
26
27
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import BaseMultiModalProcessorCache
29
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
30
31
32
33
34
35
36
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
37
    BaseDummyInputsBuilder,
38
39
40
41
42
43
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
)
汪志鹏's avatar
汪志鹏 committed
44
from vllm.sequence import IntermediateTensors
45
from vllm.utils.tensor_schema import TensorSchema, TensorShape
汪志鹏's avatar
汪志鹏 committed
46
47
48
49

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
50
51
52
53
54
55
from .utils import (
    AutoWeightsLoader,
    get_layer_index,
    init_vllm_registered_model,
    maybe_prefix,
)
56
57
58
59
60
from .vision import (
    VisionEncoderInfo,
    get_num_selected_vision_tokens,
    get_vision_encoder_info,
)
汪志鹏's avatar
汪志鹏 committed
61
62


63
64
65
66
67
68
69
70
class TarsierImagePixelInputs(TensorSchema):
    """
    Raw (not yet encoded) image pixels passed to the Tarsier vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]
汪志鹏's avatar
汪志鹏 committed
74
75


76
77
78
79
80
81
82
83
class TarsierImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings; they bypass the vision tower and projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[
        torch.Tensor,
        TensorShape("bn", "ifs", "hs"),
    ]
汪志鹏's avatar
汪志鹏 committed
87
88


89
# Image input accepted by the model: raw pixels, or embeddings that are
# already in the language model's hidden size.
TarsierImageInputs: TypeAlias = (
    TarsierImagePixelInputs | TarsierImageEmbeddingInputs
)
汪志鹏's avatar
汪志鹏 committed
90
91
92
93
94
95
96


class TarsierHfConfig(Protocol):  # Based on the Tarsier's LlavaConfig
    """Structural type listing the HF config attributes this module reads."""

    # Sub-configs of the vision encoder and the language-model backbone.
    vision_config: Final[PretrainedConfig]
    text_config: Final[PretrainedConfig]  # Added from Tarsier's LlavaConfig

    # Token id of the <image> placeholder that gets expanded in prompts.
    image_token_index: Final[int]

    # Which vision-tower outputs are used (strategy name, layer index/indices).
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[int | list[int]]

    # Activation name used between the two projector linear layers.
    projector_hidden_act: Final[str]

    # Token ids whose embeddings are inserted after every patch row
    # (image_newline_idx) and after every image (image_new_idx).
    image_newline_idx: Final[int]
    image_new_idx: Final[int]

    multimodal_projector_bias: bool = True


class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    # Defaults merged by TarsierProcessor: text is never padded and the image
    # processor gets no extra overrides.
    _defaults = dict(
        text_kwargs=dict(padding=False),
        images_kwargs=dict(),
    )


class TarsierProcessor(LlavaProcessor):
    """LLaVA processor variant that expands each image token to Tarsier's
    layout, which carries one extra token per patch row."""

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[TarsierProcessorKwargs],
    ) -> BatchFeature:
        """Run the image processor and tokenizer, expanding image tokens.

        When pixel values are produced, every ``self.image_token`` in each
        prompt is repeated
        ``(h // patch) * (w // patch + 1) + num_additional_image_tokens + 1``
        times, minus one when ``vision_feature_select_strategy == "default"``.

        Args:
            images: Images forwarded to ``self.image_processor``.
            text: A prompt string or a list of prompts.
            audio: Unused; accepted for processor-API compatibility.
            videos: Unused; accepted for processor-API compatibility.
            **kwargs: Processor kwargs merged with ``TarsierProcessorKwargs``.

        Returns:
            A ``BatchFeature`` with the tokenizer outputs and image inputs.

        Raises:
            ValueError: If both ``images`` and ``text`` are missing, or if
                ``text`` is missing or not a string / list of strings.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images, **output_kwargs["images_kwargs"]
            )
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif text is None or (
            not isinstance(text, list) and not isinstance(text[0], str)
        ):
            # `text is None` must be tested before `text[0]`; otherwise an
            # images-only call fails with an opaque TypeError here.
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )

        # try to expand inputs in processing if we have the necessary parts
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            # Replace the image token with the expanded image token sequence
            pixel_values = image_inputs["pixel_values"]
            height, width = get_image_size(to_numpy_array(pixel_values[0]))
            # The `+ 1` on the width term accounts for one extra token per
            # patch row; the trailing `+ 1` is one extra token per image.
            num_image_tokens = (
                (height // self.patch_size) * (width // self.patch_size + 1)
                + self.num_additional_image_tokens
                + 1
            )
            if self.vision_feature_select_strategy == "default":
                num_image_tokens -= 1

            prompt_strings = []
            for sample in text:
                sample = sample.replace(
                    self.image_token, self.image_token * num_image_tokens
                )
                prompt_strings.append(sample)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchFeature(
            data={**text_inputs, **image_inputs}, tensor_type=return_tensors
        )
汪志鹏's avatar
汪志鹏 committed
173
174
175


class TarsierMultiModalProjector(nn.Module):
    """Two-layer MLP (linear -> activation -> linear) mapping vision-tower
    features of size ``vision_hidden_size`` to ``text_hidden_size``."""

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Options shared by both projection layers.
        shared_opts = dict(
            bias=multimodal_projector_bias,
            quant_config=quant_config,
        )

        # NOTE: the attribute names `linear_1` / `act` / `linear_2` feed the
        # parameter names (and prefixes) used for weight loading; keep them.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **shared_opts,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **shared_opts,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the text hidden size."""
        # Each parallel linear layer returns a 2-tuple; only the first
        # element (the output tensor) is used.
        first_out, _unused = self.linear_1(image_features)
        activated = self.act(first_out)
        projected, _unused = self.linear_2(activated)
        return projected


class TarsierProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Tarsier: configs, HF processor, token counts."""

    def get_hf_config(self) -> TarsierHfConfig:
        return self.ctx.get_hf_config(HfLlavaConfig)

    def get_vision_encoder_info(self) -> VisionEncoderInfo:
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
        """Build a ``TarsierProcessor``.

        ``patch_size`` defaults to the vision encoder's patch size unless the
        caller passed one explicitly.
        """
        encoder_patch_size = self.get_vision_encoder_info().get_patch_size()
        kwargs.setdefault("patch_size", encoder_patch_size)
        return self.ctx.get_hf_processor(TarsierProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # `None` presumably means no per-prompt cap on the number of images.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many prompt tokens one image of this size expands to.

        This is the selected vision tokens plus ``floor(sqrt(n))`` row tokens
        plus one trailing token. It falls back to the largest supported image
        size when the given size yields no patches.

        Raises:
            ValueError: If even the fallback size yields no patches.
        """
        strategy = self.get_hf_config().vision_feature_select_strategy
        encoder_info = self.get_vision_encoder_info()

        def selected_tokens(width: int, height: int) -> int:
            raw_count = encoder_info.get_num_image_tokens(
                image_width=width,
                image_height=height,
            )
            return get_num_selected_vision_tokens(raw_count, strategy)

        patch_count = selected_tokens(image_width, image_height)
        if patch_count <= 0:
            fallback = self.get_image_size_with_most_features()
            patch_count = selected_tokens(fallback.width, fallback.height)
            if patch_count <= 0:
                raise ValueError("Could not determine a valid number of image patches.")

        row_count = int(math.sqrt(patch_count))
        return patch_count + row_count + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        largest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=largest.width,
            image_height=largest.height,
        )

    def get_image_newline_idx(self) -> int:
        config = self.get_hf_config()
        return config.image_newline_idx

    def get_image_new_idx(self) -> int:
        config = self.get_hf_config()
        return config.image_new_idx


# Type variable for the processing-info class used by the Tarsier
# dummy-input builder and multimodal processor.
_I_Tarsier = TypeVar(
    "_I_Tarsier",
    bound=TarsierProcessingInfo,
)


class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]):
    """Reuses LLaVA's dummy-input generation unchanged for Tarsier."""


class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]):
    """Multimodal processor that expands each image token in the prompt to
    the number of embeddings the Tarsier model produces for that image."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both pixel inputs and precomputed embeddings are batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # Token id of the <IMAGE> placeholder.
        placeholder_id = self.info.get_hf_config().image_token_index

        def expand_image(item_idx: int) -> list[int]:
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if not isinstance(items, ImageEmbeddingItems):
                size = items.get_image_size(item_idx)
                token_total = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
                return [placeholder_id] * token_total

            grid_patches = items.get_feature_size(item_idx)
            # Assumes grid_patches is a perfect square (one row token per row,
            # plus one trailing token).
            rows = int(math.sqrt(grid_patches))
            return [placeholder_id] * (grid_patches + rows + 1)

        # Each single placeholder token is replaced by its expansion.
        single_token_replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=expand_image,
        )
        return [single_token_replacement]


332
def _build_tarsier_hf_info(ctx: InputProcessingContext) -> TarsierProcessingInfo:
    """Create the Tarsier processing-info object for ``ctx``."""
    info = TarsierProcessingInfo(ctx)
    return info


def _build_tarsier_hf_processor(
    info: _I_Tarsier,
    dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Processor factory registered for the Tarsier model.

    Raises:
        NotImplementedError: If ``info`` is not a ``TarsierProcessingInfo``.
    """
    if not isinstance(info, TarsierProcessingInfo):
        raise NotImplementedError(type(info))

    return TarsierMultiModalProcessor(info, dummy_inputs, cache=cache)


def init_vision_tower_for_tarsier(
    hf_config: TarsierHfConfig,  # Use the Tarsier specific config protocol
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel:
    """Build the CLIP or SigLIP vision tower for Tarsier.

    Only the encoder layers up to the deepest one named by
    ``hf_config.vision_feature_layer`` are instantiated. The depth comes from
    ``get_layer_index``, applied per index and maximised for a list.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    feature_layers = hf_config.vision_feature_layer
    base_num_hidden_layers = vision_config.num_hidden_layers

    if isinstance(feature_layers, int):
        num_hidden_layers_to_init = get_layer_index(
            feature_layers, base_num_hidden_layers
        )
    elif isinstance(feature_layers, (list, tuple)):
        # Without this check max() would fail with an unhelpful
        # "empty sequence" error.
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        num_hidden_layers_to_init = max(
            get_layer_index(idx, base_num_hidden_layers) for idx in feature_layers
        )
    else:
        raise TypeError(
            f"vision_feature_layer type: {type(feature_layers)} is not supported"
        )

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config for Tarsier: {type(vision_config)}"
    raise NotImplementedError(msg)


400
401
402
403
404
405
@MULTIMODAL_REGISTRY.register_processor(
    _build_tarsier_hf_processor,
    info=_build_tarsier_hf_info,
    dummy_inputs=TarsierDummyInputsBuilder,
)
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Tarsier vision-language model.

    A CLIP/SigLIP vision tower and a two-layer projector produce per-image
    features. These are framed with the ``image_newline`` embedding (after
    every patch row) and the ``image_new`` embedding (after every image). The
    result is consumed by a vLLM-registered language model built from
    ``config.text_config``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``; only images are
        supported (raises ``ValueError`` otherwise)."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config: TarsierHfConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config  # Storing the Tarsier-specific HF config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_tarsier(
                config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            projector_bias = getattr(config, "multimodal_projector_bias", True)

            self.multi_modal_projector = TarsierMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )
            # Token ids of the split tokens, stored as non-persistent buffers
            # (excluded from the state dict) and embedded on demand.
            self.register_buffer(
                "image_newline_idx_tensor",
                torch.tensor([config.image_newline_idx], dtype=torch.long),
                persistent=False,
            )
            self.register_buffer(
                "image_new_idx_tensor",
                torch.tensor([config.image_new_idx], dtype=torch.long),
                persistent=False,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                # Use text_config from Tarsier's main config
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> TarsierImageInputs | None:
        """Pop image inputs from ``kwargs`` and wrap them in a schema.

        ``pixel_values`` takes precedence over ``image_embeds`` when both are
        present. Returns ``None`` when neither is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return TarsierImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
            )

        if image_embeds is not None:
            return TarsierImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run the vision tower using the configured feature-select strategy."""
        # From vLLM LLaVA, vision tower output handling
        return vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _add_tarsier_split_tokens(
        self, projected_image_features: torch.Tensor
    ) -> torch.Tensor:
        """Implements Tarsier's `add_split_tokens` logic.

        The ``num_projected_patches`` features of each image are arranged as
        a ``rows x cols`` grid with ``rows = floor(sqrt(num_projected_patches))``.
        The ``image_newline`` embedding is appended to every row, and a single
        ``image_new`` embedding is appended after the image.

        Args:
            projected_image_features: Tensor of shape
                ``(num_images, num_projected_patches, embed_dim)``.

        Returns:
            Tensor of shape
            ``(num_images, num_projected_patches + rows + 1, embed_dim)``.

        Raises:
            RuntimeError: If ``num_projected_patches`` is zero or cannot be
                arranged as a ``rows x cols`` grid.
        """
        num_images, num_projected_patches, embed_dim = projected_image_features.shape
        num_height_patches = int(math.sqrt(num_projected_patches))
        # Guard n == 0 (previously a ZeroDivisionError) and non-grid sizes
        # explicitly instead of relying on a failing view().
        num_width_patches = (
            num_projected_patches // num_height_patches if num_height_patches else 0
        )
        if (
            num_height_patches == 0
            or num_height_patches * num_width_patches != num_projected_patches
        ):
            raise RuntimeError(
                "Cannot reshape projected_image_features"
                f" with shape {projected_image_features.shape} "
                f"to ({num_images}, {num_height_patches},"
                f" {num_width_patches}, {embed_dim}). "
                "Ensure num_projected_patches is compatible"
                " with a grid structure. "
                f"num_projected_patches={num_projected_patches}, "
                f"derived num_height_patches={num_height_patches}. "
            )

        device = projected_image_features.device
        embedding_layer = self.language_model.model.embed_tokens
        image_newline_emb = embedding_layer(
            self.image_newline_idx_tensor.to(device)
        ).squeeze(0)
        image_new_emb = embedding_layer(self.image_new_idx_tensor.to(device)).squeeze(0)

        # reshape (not view) so that non-contiguous user-supplied image_embeds
        # are accepted instead of failing with a misleading grid error.
        current_image_features_grid = projected_image_features.reshape(
            num_images, num_height_patches, num_width_patches, embed_dim
        )

        image_newline_expanded = image_newline_emb.expand(
            (num_images, num_height_patches, 1, embed_dim)
        )
        features_with_newlines = torch.cat(
            [current_image_features_grid, image_newline_expanded],
            dim=2,  # Concatenate along width dim
        )
        new_num_patches_after_newline = num_projected_patches + num_height_patches
        features_with_newlines_flat = features_with_newlines.reshape(
            num_images, new_num_patches_after_newline, embed_dim
        )

        image_new_expanded = image_new_emb.expand((num_images, 1, embed_dim))
        final_image_features = torch.cat(
            [features_with_newlines_flat, image_new_expanded],
            dim=1,  # Concatenate along patch sequence dim
        )
        return final_image_features

    def _process_image_pixels(
        self,
        inputs: TarsierImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode pixels, project them, and add Tarsier's split tokens.

        Raises:
            TypeError: If the vision tower does not return a single tensor.
        """
        pixel_values = inputs["pixel_values"]
        image_features_selected = self._image_pixels_to_features(
            self.vision_tower, pixel_values
        )  # type: ignore
        if isinstance(image_features_selected, torch.Tensor):
            projected_features = self.multi_modal_projector(image_features_selected)
            final_features = self._add_tarsier_split_tokens(projected_features)
            return final_features
        else:
            raise TypeError(
                f"_image_pixels_to_features type:"
                f" {type(image_features_selected)} is not supported"
            )

    def _process_image_input(
        self,
        image_input: TarsierImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed image input into language-model embeddings.

        Precomputed ``image_embeds`` skip the vision tower and projector and
        only receive the split tokens.
        """
        if image_input["type"] == "image_embeds":
            projected_features = image_input["data"]
            if isinstance(projected_features, torch.Tensor):
                return self._add_tarsier_split_tokens(projected_features)
            else:
                raise ValueError(
                    "Incorrect type of image_embeds. "
                    f"Got type: {type(projected_features)}. "
                )

        return self._process_image_pixels(image_input)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; ``inputs_embeds`` is ignored when
        ``intermediate_tensors`` is supplied."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; returns the
        loaded parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)