tarsier.py 22.4 KB
Newer Older
汪志鹏's avatar
汪志鹏 committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
汪志鹏's avatar
汪志鹏 committed
3
4
5

import math
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union
汪志鹏's avatar
汪志鹏 committed
7
8
9

import torch
import torch.nn as nn
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
汪志鹏's avatar
汪志鹏 committed
16
17
18
from transformers import LlavaConfig as HfLlavaConfig
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.models.llava import LlavaProcessor
19
from transformers.processing_utils import ProcessingKwargs, Unpack
汪志鹏's avatar
汪志鹏 committed
20
21
22
23
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
汪志鹏's avatar
汪志鹏 committed
25
26
27
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import BaseMultiModalProcessorCache
29
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
30
31
32
33
34
35
36
37
38
39
40
41
42
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
)
汪志鹏's avatar
汪志鹏 committed
43
44
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
45
from vllm.utils.tensor_schema import TensorSchema, TensorShape
汪志鹏's avatar
汪志鹏 committed
46
47
48
49

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
50
51
52
53
54
55
56
57
58
59
from .utils import (
    AutoWeightsLoader,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import (
    VisionEncoderInfo,
    get_num_selected_vision_tokens,
    get_vision_encoder_info,
)
汪志鹏's avatar
汪志鹏 committed
60
61


62
63
64
65
66
67
68
69
class TarsierImagePixelInputs(TensorSchema):
    """
    Raw image pixels that still have to go through the vision tower and the
    multimodal projector.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
汪志鹏's avatar
汪志鹏 committed
73
74


75
76
77
78
79
80
81
82
class TarsierImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that are used as-is (no vision tower and no
    projector); only Tarsier's split tokens are added afterwards.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
汪志鹏's avatar
汪志鹏 committed
86
87


88
# Either raw pixels (run through the vision tower and projector) or already
# projected embeddings; the two are told apart by their ``type`` field.
TarsierImageInputs = Union[TarsierImagePixelInputs, TarsierImageEmbeddingInputs]
汪志鹏's avatar
汪志鹏 committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115


class TarsierHfConfig(Protocol):  # Based on Tarsier's LlavaConfig
    """Structural type listing the HF config attributes this module reads."""

    vision_config: Final[PretrainedConfig]
    text_config: Final[PretrainedConfig]  # Added from Tarsier's LlavaConfig
    # Token id of the image placeholder that is expanded in the prompt.
    image_token_index: Final[int]
    # Forwarded to the vision tower and to get_num_selected_vision_tokens.
    vision_feature_select_strategy: Final[str]
    # Vision layer index (or indices); negative values count back from the
    # last layer.
    vision_feature_layer: Final[Union[int, list[int]]]
    # Activation name passed to get_act_fn for the projector.
    projector_hidden_act: Final[str]
    # Token id whose embedding is appended after every patch row of an image.
    image_newline_idx: Final[int]
    # Token id whose embedding is appended once after an image's last row.
    image_new_idx: Final[int]
    # Model code reads this via getattr(..., True), so it may be absent.
    multimodal_projector_bias: bool = True


class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    """Default per-modality kwargs merged in by ``TarsierProcessor.__call__``.

    Text is tokenized without padding; image kwargs carry no overrides.
    """

    _defaults = dict(
        text_kwargs=dict(padding=False),
        images_kwargs={},
    )


class TarsierProcessor(LlavaProcessor):
    """``LlavaProcessor`` variant that expands image placeholders for Tarsier.

    Besides running the image processor and the tokenizer, every occurrence
    of ``self.image_token`` in a prompt is replaced by ``num_image_tokens``
    copies. The count accounts for one token per patch, one separator per
    patch row and one trailing token, so that the prompt lines up with the
    embeddings the model produces for one image.
    """

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
        ] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[TarsierProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess ``images`` and tokenize ``text`` into a ``BatchFeature``.

        Args:
            images: Images handed to ``self.image_processor``.
            text: A prompt string, or a list/tuple of prompts. When images are
                given, every prompt must be a string so its image tokens can
                be expanded.
            audio: Unused; accepted for processor-API compatibility.
            videos: Unused; accepted for processor-API compatibility.
            **kwargs: Processor kwargs, merged with the defaults declared in
                ``TarsierProcessorKwargs``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs plus, when images
            were given, the image-processor outputs.

        Raises:
            ValueError: If neither ``images`` nor ``text`` is given, if ``text``
                is not a string or a list/tuple, or if a prompt that needs
                image-token expansion is not a string.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images, **output_kwargs["images_kwargs"]
            )
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)):
            # Reject anything that is not a sequence of prompts up front
            # (e.g. text=None with images only) instead of indexing into it.
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )

        # Try to expand inputs in processing if we have the necessary parts.
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            pixel_values = image_inputs["pixel_values"]
            # Only the first image is measured; assumes the image processor
            # resizes every image to the same resolution.
            height, width = get_image_size(to_numpy_array(pixel_values[0]))
            # rows * (cols + 1 row separator) + extra vision-tower tokens
            # + 1 trailing token.
            num_image_tokens = (
                (height // self.patch_size) * (width // self.patch_size + 1)
                + self.num_additional_image_tokens
                + 1
            )
            # The "default" strategy drops one vision-tower token from the
            # features (presumably CLS), so one fewer placeholder is needed.
            if self.vision_feature_select_strategy == "default":
                num_image_tokens -= 1

            expanded_image_tokens = self.image_token * num_image_tokens
            prompt_strings = []
            for sample in text:
                if not isinstance(sample, str):
                    raise ValueError(
                        "Invalid input text. Please provide a string, "
                        "or a list of strings"
                    )
                prompt_strings.append(
                    sample.replace(self.image_token, expanded_image_tokens)
                )

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchFeature(
            data={**text_inputs, **image_inputs}, tensor_type=return_tensors
        )
汪志鹏's avatar
汪志鹏 committed
171
172
173


class TarsierMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision-tower features into the LM hidden space.

    ``linear_1`` (column-parallel) maps ``vision_hidden_size`` to
    ``text_hidden_size``, then ``projector_hidden_act`` is applied, then
    ``linear_2`` (row-parallel) maps ``text_hidden_size`` to itself.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Both linear layers share the bias flag and quantization config.
        common_linear_kwargs = dict(
            bias=multimodal_projector_bias,
            quant_config=quant_config,
        )
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **common_linear_kwargs,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **common_linear_kwargs,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a 2-tuple; only the first
        # element (the output tensor) is used.
        projected, _ = self.linear_1(image_features)
        out, _ = self.linear_2(self.act(projected))
        return out


class TarsierProcessingInfo(BaseProcessingInfo):
    """Tarsier-specific processing metadata.

    Provides the HF config and processor, and the number of prompt tokens a
    single image expands to.
    """

    def get_hf_config(self) -> TarsierHfConfig:
        # Read the HF config as a transformers LlavaConfig.
        return self.ctx.get_hf_config(HfLlavaConfig)

    def get_vision_encoder_info(self) -> VisionEncoderInfo:
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
        # The processor needs the vision encoder's patch size to expand image
        # placeholders; an explicitly passed ``patch_size`` wins.
        encoder_patch_size = self.get_vision_encoder_info().get_patch_size()
        if "patch_size" not in kwargs:
            kwargs["patch_size"] = encoder_patch_size
        return self.ctx.get_hf_processor(TarsierProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed limit on images per prompt.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of prompt tokens one image of this size uses.

        This is the number of selected vision patches, plus one separator per
        patch row (rows = ``int(sqrt(patches))``), plus one trailing token. If
        the encoder yields no patches for the given size, the size from
        ``get_image_size_with_most_features`` is used instead.

        Raises:
            ValueError: If no positive patch count can be determined.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        def _selected_patches(width: int, height: int) -> int:
            # Patch tokens that survive the configured feature selection.
            raw_tokens = encoder_info.get_num_image_tokens(
                image_width=width,
                image_height=height,
            )
            return get_num_selected_vision_tokens(
                raw_tokens, hf_config.vision_feature_select_strategy
            )

        patches = _selected_patches(image_width, image_height)
        if patches <= 0:
            fallback = self.get_image_size_with_most_features()
            patches = _selected_patches(fallback.width, fallback.height)
            if patches <= 0:
                raise ValueError("Could not determine a valid number of image patches.")

        rows = int(math.sqrt(patches))  # assumes a square patch grid
        return patches + rows + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square whose side is the vision encoder's image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        largest_size = self.get_image_size_with_most_features()
        target_width, target_height = largest_size
        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_image_newline_idx(self) -> int:
        hf_config = self.get_hf_config()
        return hf_config.image_newline_idx

    def get_image_new_idx(self) -> int:
        hf_config = self.get_hf_config()
        return hf_config.image_new_idx


# Type variable for the processing-info type used by the Tarsier dummy-input
# builder and multimodal processor; bounded by TarsierProcessingInfo.
_I_Tarsier = TypeVar("_I_Tarsier", bound=TarsierProcessingInfo)


class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]):
    """Dummy-input builder for Tarsier; reuses LLaVA's builder unchanged."""


class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]):
    """Multimodal processor that expands each image token id in the prompt."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # The <IMAGE> token id from the HF config.
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                patch_count = images.get_feature_size(item_idx)
                # patches + one separator per row + one trailing token;
                # assumes the patch count is a perfect square.
                token_count = patch_count + int(math.sqrt(patch_count)) + 1

            return [image_token_id] * token_count

        # Each single image token is replaced by the per-image expansion.
        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


330
def _build_tarsier_hf_info(ctx: InputProcessingContext) -> TarsierProcessingInfo:
    """Create the Tarsier processing info bound to the given context."""
    processing_info = TarsierProcessingInfo(ctx)
    return processing_info


def _build_tarsier_hf_processor(
    info: _I_Tarsier,
    dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
    *,
    cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches ``info``.

    Raises:
        NotImplementedError: If ``info`` is not a ``TarsierProcessingInfo``.
    """
    if not isinstance(info, TarsierProcessingInfo):
        raise NotImplementedError(type(info))

    return TarsierMultiModalProcessor(info, dummy_inputs, cache=cache)


def init_vision_tower_for_tarsier(
    hf_config: TarsierHfConfig,  # Use the Tarsier specific config protocol
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel]:
    """Build the CLIP or SigLIP vision tower used by Tarsier.

    Only as many encoder layers as the deepest entry of
    ``hf_config.vision_feature_layer`` requires are instantiated; this count
    is passed as ``num_hidden_layers_override``.

    Args:
        hf_config: Config providing ``vision_config`` and
            ``vision_feature_layer`` (an int or a list/tuple of ints; negative
            values count back from the last layer, with -1 meaning all layers).
        quant_config: Optional quantization config for the tower.
        require_post_norm: Forwarded to the vision model constructor.
        prefix: Weight-name prefix for the tower.

    Returns:
        A ``CLIPVisionModel`` or ``SiglipVisionModel``.

    Raises:
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
        NotImplementedError: If ``vision_config`` is neither a CLIP nor a
            SigLIP vision config.
    """
    vision_config = hf_config.vision_config

    feature_layers = hf_config.vision_feature_layer
    base_num_hidden_layers = vision_config.num_hidden_layers

    def _get_layer_index(feature_layer_index: int, num_hidden_layers_total: int) -> int:
        # Convert a (possibly negative) feature-layer index into the number of
        # layers that must exist: -1 -> num_hidden_layers_total.
        if feature_layer_index < 0:
            return num_hidden_layers_total + feature_layer_index + 1
        return feature_layer_index

    if isinstance(feature_layers, int):
        num_hidden_layers_to_init = _get_layer_index(
            feature_layers, base_num_hidden_layers
        )
    elif isinstance(feature_layers, (list, tuple)):
        num_hidden_layers_to_init = max(
            _get_layer_index(idx, base_num_hidden_layers) for idx in feature_layers
        )
    else:
        # Name the actual config attribute so the error is searchable.
        raise TypeError(
            f"vision_feature_layer type: {type(feature_layers)} is not supported"
        )

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config for Tarsier: {type(vision_config)}"
    raise NotImplementedError(msg)


400
401
402
403
404
405
@MULTIMODAL_REGISTRY.register_processor(
    _build_tarsier_hf_processor,
    info=_build_tarsier_hf_info,
    dummy_inputs=TarsierDummyInputsBuilder,
)
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Tarsier vision-language model: vision tower, MLP projector and LM.

    Per image, the projected patch features are viewed as a grid with
    ``int(sqrt(num_patches))`` rows. The embedding of ``image_newline_idx``
    is appended after every row, and one ``image_new_idx`` embedding is
    appended at the very end.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config: TarsierHfConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config  # Tarsier-specific HF config

        self.vision_tower = init_vision_tower_for_tarsier(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )

        bias_enabled = getattr(config, "multimodal_projector_bias", True)
        self.multi_modal_projector = TarsierMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=bias_enabled,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )

        # The language model is built from the nested text config.
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Split-token ids kept as non-persistent buffers (not in state_dict).
        split_token_buffers = (
            ("image_newline_idx_tensor", config.image_newline_idx),
            ("image_new_idx_tensor", config.image_new_idx),
        )
        for buffer_name, token_id in split_token_buffers:
            self.register_buffer(
                buffer_name,
                torch.tensor([token_id], dtype=torch.long),
                persistent=False,
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[TarsierImageInputs]:
        # Pixel values take precedence over precomputed embeddings.
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            return TarsierImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
            )
        if image_embeds is not None:
            return TarsierImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )
        return None

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # Feature selection is delegated to the tower (as in vLLM LLaVA).
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _add_tarsier_split_tokens(
        self, projected_image_features: torch.Tensor
    ) -> torch.Tensor:
        """Insert Tarsier's split-token embeddings (its ``add_split_tokens``).

        Args:
            projected_image_features: Tensor of shape
                ``(num_images, num_patches, embed_dim)``.

        Returns:
            Tensor of shape
            ``(num_images, num_patches + rows + 1, embed_dim)`` with
            ``rows = int(sqrt(num_patches))``: each row of patches followed by
            the newline embedding, then one trailing image_new embedding.

        Raises:
            RuntimeError: If ``num_patches`` cannot be viewed as a
                ``rows x (num_patches // rows)`` grid.
        """
        num_images, num_projected_patches, embed_dim = projected_image_features.shape
        num_height_patches = int(math.sqrt(num_projected_patches))
        num_width_patches = num_projected_patches // num_height_patches

        device = projected_image_features.device
        embed = self.language_model.model.embed_tokens
        image_newline_emb = embed(self.image_newline_idx_tensor.to(device)).squeeze(0)
        image_new_emb = embed(self.image_new_idx_tensor.to(device)).squeeze(0)

        try:
            patch_grid = projected_image_features.view(
                num_images, num_height_patches, num_width_patches, embed_dim
            )
        except RuntimeError as e:
            raise RuntimeError(
                "Cannot reshape projected_image_features"
                f" with shape {projected_image_features.shape} "
                f"to ({num_images}, {num_height_patches},"
                f" {num_width_patches}, {embed_dim}). "
                "Ensure num_projected_patches is compatible"
                " with a grid structure. "
                f"num_projected_patches={num_projected_patches}, "
                f"derived num_height_patches={num_height_patches}. "
            ) from e

        # Append one newline embedding to the end of every patch row, then
        # flatten rows back into a single patch sequence.
        newline_column = image_newline_emb.expand(
            (num_images, num_height_patches, 1, embed_dim)
        )
        rows_with_newlines = torch.cat([patch_grid, newline_column], dim=2).view(
            num_images, num_projected_patches + num_height_patches, embed_dim
        )

        # Close every image with a single image_new embedding.
        trailer = image_new_emb.expand((num_images, 1, embed_dim))
        return torch.cat([rows_with_newlines, trailer], dim=1)

    def _process_image_pixels(
        self,
        inputs: TarsierImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        selected_features = self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )  # type: ignore
        if not isinstance(selected_features, torch.Tensor):
            raise TypeError(
                f"_image_pixels_to_features type:"
                f" {type(selected_features)} is not supported"
            )
        projected = self.multi_modal_projector(selected_features)
        return self._add_tarsier_split_tokens(projected)

    def _process_image_input(
        self,
        image_input: TarsierImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        if image_input["type"] != "image_embeds":
            assert self.vision_tower is not None
            return self._process_image_pixels(image_input)

        # Precomputed embeddings skip the tower and projector.
        projected_features = image_input["data"]
        if not isinstance(projected_features, torch.Tensor):
            raise ValueError(
                "Incorrect type of image_embeds. "
                f"Got type: {type(projected_features)}. "
            )
        return self._add_tarsier_split_tokens(projected_features)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            # Later pipeline stages consume hidden states, not embeddings.
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            image_token_mask = input_ids == self.config.image_token_index
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                vision_embeddings,
                is_multimodal=image_token_mask,
            )
            input_ids = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(weights)