ovis2_5.py 23.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""PyTorch Ovis model."""

5
6
from collections.abc import Iterable, Mapping
from functools import partial
7
from typing import Literal, Optional, TypedDict, Union
8
9
10
11
12
13

import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
14
from vllm.config.multimodal import BaseDummyOptions
15
from vllm.model_executor.layers.linear import ReplicatedLinear
16
from vllm.model_executor.layers.quantization import QuantizationConfig
17
from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding
18
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
19
20
21
22
23
24
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
31
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
32
33
34
35
36
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
37
38
39
40
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

# Prompt placeholders that the multimodal processor replaces per modality.
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
# Negative placeholder ids for visual indicators. Their count is also the
# number of slots reserved at the end of the visual vocabulary (the visual
# tokenizer head only predicts ``visual_vocab_size - len(INDICATOR_IDS)``).
INDICATOR_IDS = [-301, -302, -303, -304]

# Image-pad token string, keyed by the text backbone's ``model_type``.
IMAGE_PAD_TOKEN_MAP = dict(
    gemma2="<unused0>",
    llama="<|reserved_special_token_0|>",
    qwen2="<|image_pad|>",
    qwen3="<|image_pad|>",
)
# Token id of the image-pad token, keyed the same way as the map above.
IMAGE_PAD_TOKEN_ID_MAP = dict(
    gemma2=7,
    llama=128002,
    qwen2=151655,
    qwen3=151655,
)
class OvisVideoPatchInputs(TypedDict):
    """Flattened video inputs consumed by the Ovis2.5 visual tokenizer."""

    type: Literal["video_patches"]
    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    indicator_tokens: torch.Tensor
    """
    Shape:
    `(batch_size * (num_patches + 1))`
    """

    patches_per_image: list[int]
    """
    Number of visual tokens (raw patches divided by `hidden_stride**2`) for
    each video item. This is used to split the visual embeddings computed
    from `flat_data` back into per-video chunks.
    """

    grids: torch.Tensor
    """
    Per-video grid descriptors passed to the visual tokenizer as `grid_thws`.
    Presumably `(num_videos, 3)` holding (t, h, w) — TODO confirm.
    """
def _ovis2_5_field_config():
    """Describe how each processor output field is batched.

    Image outputs are batched per ``"image"`` item and the ``video_``-prefixed
    outputs per ``"video"`` item. Each field gets its own config object.
    """
    image_fields = ("pixel_values", "grids", "indicator_tokens")
    video_fields = ("video_pixel_values", "video_indicator_tokens", "video_grids")

    field_config = {
        name: MultiModalFieldConfig.batched("image") for name in image_fields
    }
    field_config.update(
        {name: MultiModalFieldConfig.batched("video") for name in video_fields}
    )
    return field_config


class VisualTokenizer(torch.nn.Module):
    """Vision backbone followed by a projection head.

    The head (linear + LayerNorm) is softmaxed into a distribution over the
    visual vocabulary. The trailing ``len(INDICATOR_IDS)`` reserved slots are
    filled with zero probability.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        visual_vocab_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.vit = self._init_backbone(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.vit",
            use_data_parallel=use_data_parallel,
        )

        # The last len(INDICATOR_IDS) vocabulary slots are reserved for the
        # indicator tokens, so the head only predicts the remaining ones.
        head_dim = visual_vocab_size - len(INDICATOR_IDS)
        # Input width after merging hidden_stride**2 ViT features per token.
        merged_dim = config.hidden_size * config.hidden_stride**2
        projection = ReplicatedLinear(
            merged_dim,
            head_dim,
            bias=False,
            return_bias=False,
        )
        self.head = torch.nn.Sequential(projection, torch.nn.LayerNorm(head_dim))

    def _init_backbone(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the ViT backbone; only ``siglip2_navit`` is supported.

        Raises:
            ValueError: if ``config.model_type`` is any other value.
        """
        model_type = config.model_type
        if model_type != "siglip2_navit":
            raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
        return Siglip2NavitModel(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            use_data_parallel=use_data_parallel,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter of the projection head."""
        first_param = next(self.head.parameters())
        return first_param.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the projection head."""
        first_param = next(self.head.parameters())
        return first_param.device

    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
        """Softmax over the last dim in float32, cast back to ``logits.dtype``."""
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        return probs.to(logits.dtype)

    def encode(
        self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
    ) -> torch.Tensor:
        """Run the ViT and merge every ``hidden_stride**2`` rows into one."""
        vit_out = self.vit(pixel_values, grid_thws)
        # Mirrors the Qwen2.5-VL patch merger: consecutive rows are
        # concatenated along the feature dim.
        num_rows, _ = vit_out.shape
        rows_per_token = self.config.hidden_stride**2
        return vit_out.reshape(num_rows // rows_per_token, -1)

    def forward(
        self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
    ) -> torch.Tensor:
        """Return soft visual tokens of width ``visual_vocab_size``."""
        logits = self.head(self.encode(pixel_values, grid_thws))
        probs = self.tokenize(logits)
        # probs is [#Token, VocabSize - 4]; right-pad the reserved indicator
        # slots with zeros so the result is [#Token, VocabSize].
        reserved = len(INDICATOR_IDS)
        return torch.nn.functional.pad(
            probs,
            (0, reserved),
            mode="constant",
            value=0,
        )


class Ovis2_5ProcessingInfo(BaseProcessingInfo):
    """Configs, processor construction and token budgets for Ovis2.5."""

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs):
        """Build the Ovis2.5 processor from the ViT config.

        NOTE(review): ``kwargs`` is accepted for interface compatibility but
        is not forwarded to ``ctx.get_hf_processor``.
        """
        vit_config = self.get_hf_config().vit_config
        return self.ctx.get_hf_processor(
            Ovis2_5Processor,
            image_pad_token=self.get_image_pad_token(),
            patch_size=vit_config.patch_size,
            hidden_stride=vit_config.hidden_stride,
            temporal_patch_size=vit_config.temporal_patch_size,
        )

    def get_image_pad_token(self) -> Optional[str]:
        """Return the image-pad token for the text backbone.

        Returns ``None`` when the backbone's ``model_type`` is not a key of
        ``IMAGE_PAD_TOKEN_MAP``.
        """
        hf_text_config = self.get_hf_config().get_text_config()
        return IMAGE_PAD_TOKEN_MAP.get(hf_text_config.model_type)

    def get_image_processor(self) -> BaseImageProcessor:
        return self.get_hf_processor().image_processor  # type: ignore

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Unlimited images, at most one video per prompt.
        return {"image": None, "video": 1}

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code
        # TODO(myselvess): Be adjusted based on the max_pixels
        return ImageSize(width=1792, height=1792)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
    ) -> int:
        """Return the patch count for an image (or video) of the given size.

        Patches are counted at ``patch_size`` granularity over a temporal grid
        of ``ceil(num_frames / temporal_patch_size)`` steps (at least 1). No
        ``hidden_stride`` merge is applied here.
        """
        vit_config = self.get_hf_config().vit_config
        patch_size = vit_config.patch_size
        temporal_patch_size = vit_config.temporal_patch_size
        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + (-num_frames % temporal_patch_size)
        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = image_height // patch_size
        grid_w = image_width // patch_size
        return grid_t * grid_h * grid_w

    def get_max_image_tokens(self) -> int:
        """Token count of the largest supported image."""
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=target_width, image_height=target_height
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose max-size video fits in ``max_tokens``.

        Returns 0 if even a single frame does not fit.
        """
        target_width, target_height = self.get_image_size_with_most_features()
        num_frames = 0
        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )
            if next_max_tokens > max_tokens:
                break
            num_frames = next_num_frames
        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video that fit once the max-size images are budgeted.

        Always returns at least 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)
        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        max_frames_per_video = max_total_frames // max(max_videos, 1)
        return max(max_frames_per_video, 1)

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Optional[BaseImageProcessor],
    ) -> int:
        """Token count for a video; ``image_processor`` is currently unused."""
        return self.get_num_image_tokens(
            image_width=image_width, image_height=image_height, num_frames=num_frames
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Token count of a max-size video with the most frames that fit."""
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )


class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):
    """Dummy prompt text and media used for Ovis2.5 profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """All image placeholders first, then all video placeholders."""
        image_part = IMAGE_TOKEN * mm_counts.get("image", 0)
        video_part = VIDEO_TOKEN * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Max-size dummy images and videos sized to fit ``seq_len``."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = None
            video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}
class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]):
    """Multimodal processor that adds visual-indicator token ids."""

    def visual_indicators_to_visual_tokens(
        self,
        visual_indicators: list[int],
    ) -> list[int]:
        """
        Keep only indicator placeholders (ids below -300) and map them onto
        the reserved slots at the end of the visual vocabulary: -301 maps to
        the first reserved slot, -302 to the next, and so on.
        """
        vte_vocab_size = self.info.get_hf_config().visual_vocab_size
        first_reserved = vte_vocab_size - len(INDICATOR_IDS)
        visual_tokens = []
        for indicator in visual_indicators:
            if indicator < -300:
                visual_tokens.append(first_reserved + abs(indicator + 300) - 1)
        return visual_tokens

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach indicator tokens per item."""
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        hf_processor = self.info.get_hf_processor()

        # (mm_data key, grids output key, indicator output key, is_video).
        # One indicator sequence is built per grid entry; the grid values
        # themselves are not used, only their count.
        indicator_specs = (
            ("videos", "video_grids", "video_indicator_tokens", True),
            ("images", "grids", "indicator_tokens", False),
        )
        for data_key, grids_key, output_key, is_video in indicator_specs:
            if data_key not in mm_data:
                continue
            processed_outputs[output_key] = [
                self.visual_indicators_to_visual_tokens(
                    hf_processor.construct_visual_indicators((1, 1, 1), is_video)
                )
                for _ in processed_outputs[grids_key]
            ]
        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # Token-only prompts are passed through unchanged.
        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _ovis2_5_field_config()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptReplacement]:
        """Replace each ``<image>``/``<video>`` with its visual placeholders."""

        def get_replacement_ovis(item_idx, modality: str):
            if modality == "image":
                grid = out_mm_kwargs["image"][item_idx]["grids"].data
            elif modality == "video":
                grid = out_mm_kwargs["video"][item_idx]["video_grids"].data
            hf_processor = self.info.get_hf_processor()
            return hf_processor.construct_visual_placeholders(grid[0])

        replacements = []
        for modality in ("image", "video"):
            target = IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN
            replacements.append(
                PromptReplacement(
                    modality=modality,
                    target=target,
                    replacement=partial(get_replacement_ovis, modality=modality),
                )
            )
        return replacements
@MULTIMODAL_REGISTRY.register_processor(
    Ovis2_5MultiModalProcessor,
    info=Ovis2_5ProcessingInfo,
    dummy_inputs=Ovis2_5DummyInputsBuilder,
)
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
    """Ovis2.5 model: a visual tokenizer and visual embedding table (``vte``)
    producing multimodal embeddings for a registered vLLM text model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config: PretrainedConfig = config
        self.llm = init_vllm_registered_model(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "llm"),
        )

        self.visual_tokenizer = VisualTokenizer(
            config=config.vit_config,
            visual_vocab_size=config.visual_vocab_size,
            quant_config=quant_config,
            # maybe_prefix avoids a leading "." when prefix is empty, matching
            # how the llm prefix is built above.
            prefix=maybe_prefix(prefix, "visual_tokenizer"),
        )

        self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)

        text_model_type = self.config.get_text_config().model_type
        self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

        self.make_empty_intermediate_tensors = (
            self.get_language_model().make_empty_intermediate_tensors
        )

    def _flatten_patch_kwargs(
        self,
        pixel_values: object,
        indicator_tokens: object,
        grids: object,
        modality: str,
    ) -> Optional[dict]:
        """Validate and flatten one modality's patch kwargs.

        Returns ``None`` when neither pixel values nor indicator tokens are
        given. Otherwise, returns the ``flat_data``, ``patches_per_image``,
        ``indicator_tokens`` and ``grids`` fields.

        Raises:
            ValueError: if only one of pixel values / indicator tokens is
                given, or either has an unexpected type.
        """
        if pixel_values is None and indicator_tokens is None:
            return None

        if pixel_values is None or indicator_tokens is None:
            raise ValueError(
                f"Both pixel values and indicator tokens are required for "
                f"{modality} inputs, but only one of them was provided."
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
            )

        if not isinstance(indicator_tokens, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of indicator_tokens. "
                f"Got type: {type(indicator_tokens)}"
            )

        # Each visual token covers hidden_stride**2 raw patches.
        patches_per_token = self.config.vit_config.hidden_stride**2
        return dict(
            flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
            patches_per_image=[
                x.shape[0] // patches_per_token for x in flatten_bn(pixel_values)
            ],
            indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
            grids=flatten_bn(flatten_bn(grids), concat=True),
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[OvisImagePatchInputs]:
        """Build image patch inputs from processor kwargs, or ``None``."""
        fields = self._flatten_patch_kwargs(
            kwargs.pop("pixel_values", None),
            kwargs.pop("indicator_tokens", None),
            kwargs.pop("grids", None),
            modality="image",
        )
        if fields is None:
            return None
        return OvisImagePatchInputs(type="image_patches", **fields)

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Optional[OvisVideoPatchInputs]:
        """Build video patch inputs from processor kwargs, or ``None``."""
        fields = self._flatten_patch_kwargs(
            kwargs.pop("video_pixel_values", None),
            kwargs.pop("video_indicator_tokens", None),
            kwargs.pop("video_grids", None),
            modality="video",
        )
        if fields is None:
            return None
        return OvisVideoPatchInputs(type="video_patches", **fields)

    def _process_image_input(
        self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
    ) -> MultiModalEmbeddings:
        """Embed each item as [first indicator, visual tokens, rest]."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]
        indicator_tokens = image_input["indicator_tokens"]
        grid_thws = image_input["grids"]

        # Indicator rows per item: 2 when the item has more than one visual
        # token, otherwise n + 2 (i.e. 3 for a single token).
        indicator_per_image = [2 if n > 1 else n + 2 for n in patches_per_image]

        target_dtype = self.visual_tokenizer.dtype
        visual_tokens = self.visual_tokenizer(
            image_patches_flat.to(target_dtype), grid_thws
        )

        visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
        indicator_embeds = self.vte(indicator_tokens)

        visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
        indicator_embeds_per_image = indicator_embeds.split(indicator_per_image)

        return tuple(
            torch.cat([indicator[:1], visual, indicator[1:]], dim=0)
            for indicator, visual in zip(
                indicator_embeds_per_image, visual_embeds_per_image
            )
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed "images"/"videos" in kwargs order."""
        image_keys = ("pixel_values", "indicator_tokens", "grids")
        video_keys = ("video_pixel_values", "video_indicator_tokens", "video_grids")
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in image_keys and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if input_key in video_keys and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item embeddings for all images/videos in ``kwargs``."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for mm_input in modalities.values():
            # A modality can parse to None (e.g. only its grids were passed);
            # there is nothing to embed for it.
            if mm_input is None:
                continue
            multimodal_embeddings += self._process_image_input(mm_input)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the language model.

        ``inputs_embeds`` is dropped when ``intermediate_tensors`` are given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # up until here we have a inputs_embeds 100% numerical identity
        # between the OG HF Transformers implementation and ours
        hidden_states = self.llm(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        logits = self.llm.compute_logits(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_language_model(self) -> torch.nn.Module:
        return self.llm