ovis2_5.py 23.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""PyTorch Ovis model."""

5
6
from collections.abc import Iterable, Mapping
from functools import partial
7
from typing import Annotated, Literal
8
9
10
11
12

import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

13
from vllm.config import VllmConfig
14
from vllm.config.multimodal import BaseDummyOptions
15
from vllm.model_executor.layers.linear import ReplicatedLinear
16
from vllm.model_executor.layers.quantization import QuantizationConfig
17
from vllm.model_executor.models.ovis import VisualEmbedding
18
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
19
20
21
22
23
24
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
31
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
32
from vllm.multimodal.processing import (
33
    BaseDummyInputsBuilder,
34
35
36
37
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
38
from vllm.renderers import TokenizeParams
39
40
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
41
from vllm.utils.tensor_schema import TensorSchema, TensorShape
42

43
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
44
45

# Placeholder strings inserted into prompts for each image / video.
IMAGE_TOKEN, VIDEO_TOKEN = "<image>", "<video>"
# Presumably the tokenizer ids of the two placeholders above — verify against
# the tokenizer before relying on them.
IMAGE_PLACEHOLDER_ID, VIDEO_PLACEHOLDER_ID = 151669, 151670
# Ids of the visual indicator tokens. The visual tokenizer reserves
# len(INDICATOR_IDS) slots at the end of the visual vocabulary for them.
INDICATOR_IDS = [151672, 151673, 151674, 151675]
IMAGE_PAD_TOKEN_ID = 151655
THINK_END_TOKEN_ID = 151668
52
53


54
class Ovis2_5ImagePatchInputs(TensorSchema):
    """
    Flattened image patches plus per-image metadata for the visual tokenizer.

    Dimensions:
        - bnp: Batch size * number of images * number of patches
        - patch_size: patch_size_x * patch_size_y * num_channels
        - patch_indicators: Total number of indicator tokens over all images
          (the same fixed number for every image)
        - bn: Batch size * number of images
    """

    type: Literal["image_patches"]
    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
    # Number of visual tokens each image yields after the ViT merges every
    # hidden_stride**2 patches into one; used to split the flattened
    # embeddings back into per-image chunks.
    patches_per_item: Annotated[list[int], TensorShape("bn")]
    # Per-image grid with three values per image (presumably (t, h, w)).
    grids: Annotated[torch.Tensor, TensorShape("bn", 3)]

70

71
class Ovis2_5VideoPatchInputs(TensorSchema):
    """
    Flattened video patches plus per-video metadata for the visual tokenizer.

    Dimensions:
        - bnp: Batch size * number of videos * number of patches
        - patch_size: patch_size_x * patch_size_y * num_channels
        - patch_indicators: Total number of indicator tokens over all videos
          (the same fixed number for every video)
        - bn: Batch size * number of videos
    """

    type: Literal["video_patches"]
    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
    # Number of visual tokens each video yields after the ViT merges every
    # hidden_stride**2 patches into one; used to split the flattened
    # embeddings back into per-video chunks.
    patches_per_item: Annotated[list[int], TensorShape("bn")]
    # Per-video grid with three values per video (presumably (t, h, w)).
    grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
86
87
88
89
90
91
92
93
94
95
96


class VisualTokenizer(torch.nn.Module):
    """
    ViT-based visual tokenizer.

    Patch features from the backbone are merged in groups of
    ``hidden_stride**2`` consecutive rows, projected by a linear + LayerNorm
    head, and softmax-normalized into probability distributions ("soft
    tokens") over the visual vocabulary.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        visual_vocab_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.vit = self._init_backbone(
            config=config,
            quant_config=quant_config,
            # maybe_prefix avoids a dangling leading "." when prefix is empty.
            prefix=maybe_prefix(prefix, "vit"),
        )
        # The last len(INDICATOR_IDS) entries of the visual vocabulary are
        # reserved for indicator tokens: the head predicts only the others and
        # `forward` pads the reserved columns back in with zeros.
        head_dim = visual_vocab_size - len(INDICATOR_IDS)
        self.head = torch.nn.Sequential(
            ReplicatedLinear(
                self.config.hidden_size * self.config.hidden_stride**2,
                head_dim,
                bias=False,
                return_bias=False,
            ),
            torch.nn.LayerNorm(head_dim),
        )

    def _init_backbone(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Build the ViT backbone selected by ``config.model_type``.

        Raises:
            ValueError: if ``config.model_type`` is not ``"siglip2_navit"``.
        """
        model_type = config.model_type
        if model_type == "siglip2_navit":
            return Siglip2NavitModel(
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            )
        raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the head's parameters."""
        return next(self.head.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device holding the head's parameters."""
        return next(self.head.parameters()).device

    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Softmax ``logits`` over the last dimension.

        The softmax is computed in float32 for stability and the result is cast
        back to ``logits.dtype``.
        """
        tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
        return tokens

    def encode(
        self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
    ) -> torch.Tensor:
        """
        Run the ViT and merge every ``hidden_stride**2`` consecutive feature rows.

        The backbone output is unpacked as a 2-D ``(seq_len, hidden)`` tensor.
        It is reshaped to ``(seq_len // hidden_stride**2, -1)``, so ``seq_len``
        must be divisible by ``hidden_stride**2``.
        """
        features = self.vit(pixel_values, grid_thws)
        # refer to qwen2.5-vl patchmerger
        seq_len, _ = features.shape
        features = features.reshape(seq_len // (self.config.hidden_stride**2), -1)
        return features

    def forward(
        self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
    ) -> torch.Tensor:
        """
        Return soft visual tokens of shape ``(#Token, visual_vocab_size)``.

        The reserved indicator columns at the end are always zero.
        """
        features = self.encode(pixel_values, grid_thws)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [#Token, VocabSize-4]; pad with [#Token, 4] zeros for
        # the reserved indicator slots so the shape becomes [#Token, VocabSize].
        tokens = torch.nn.functional.pad(
            tokens,
            (0, len(INDICATOR_IDS)),
            mode="constant",
            value=0,
        )
        return tokens


class Ovis2_5ProcessingInfo(BaseProcessingInfo):
    """Processor access, multimodal limits and token-count estimates for Ovis2.5."""

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs):
        """
        Return the Ovis2.5 processor configured from the ViT config.

        NOTE: ``kwargs`` is accepted for interface compatibility but is not
        forwarded; the processor is always built from ``vit_config`` values.
        """
        vit_config = self.get_hf_config().vit_config
        return self.ctx.get_hf_processor(
            Ovis2_5Processor,
            patch_size=vit_config.patch_size,
            hidden_stride=vit_config.hidden_stride,
            temporal_patch_size=vit_config.temporal_patch_size,
        )

    def get_default_tok_params(self) -> TokenizeParams:
        # Prompts are tokenized without adding special tokens.
        return super().get_default_tok_params().with_kwargs(add_special_tokens=False)

    def get_image_processor(self) -> BaseImageProcessor:
        return self.get_hf_processor().image_processor  # type: ignore

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Unlimited images, at most one video per prompt.
        return {"image": None, "video": 1}

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code
        # TODO(myselvess): Be adjusted based on the max_pixels
        return ImageSize(width=1792, height=1792)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
    ) -> int:
        """
        Estimate the vision tokens of an image (or of ``num_frames`` frames).

        Returns the raw ViT patch count ``grid_t * grid_h * grid_w``. The count
        is not divided by ``hidden_stride**2`` and excludes indicator tokens;
        presumably it is meant as a conservative budget (confirm before
        relying on it as an exact count).
        """
        vit_config = self.get_hf_config().vit_config
        patch_size = vit_config.patch_size
        temporal_patch_size = vit_config.temporal_patch_size
        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + (-num_frames % temporal_patch_size)
        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = image_height // patch_size
        grid_w = image_width // patch_size
        return grid_t * grid_h * grid_w

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=target_width, image_height=target_height
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """
        Return the largest frame count whose token estimate fits in ``max_tokens``.

        Returns 0 when even a single frame does not fit.
        """
        target_width, target_height = self.get_image_size_with_most_features()
        num_frames = 0
        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )
            if next_max_tokens > max_tokens:
                break
            num_frames = next_num_frames
        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """
        Split the token budget left after the images evenly across videos.

        Always returns at least 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)
        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        max_frames_per_video = max_total_frames // max(max_videos, 1)
        return max(max_frames_per_video, 1)

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        # Videos use the same estimate as images, with the frame count applied.
        return self.get_num_image_tokens(
            image_width=image_width, image_height=image_height, num_frames=num_frames
        )

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
        )


class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):
    """Build dummy prompts and media at the largest supported image size."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one placeholder string per requested image and video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """
        Return dummy images and videos at the largest supported size.

        Videos use the largest frame count that fits in ``seq_len``.
        Per-modality overrides from ``mm_options`` are passed through.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image")
        video_overrides = mm_options.get("video")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }


321
class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]):
    """Runs the HF processor, adds indicator tokens and expands placeholders."""

    def visual_indicators_to_visual_tokens(
        self,
        visual_indicators: list[int],
    ) -> list[int]:
        """
        Filter image indicators placeholders and convert them to corresponding
        tokens in visual tokenizer.

        Ids ``>= INDICATOR_IDS[0]`` are mapped onto the reserved tail of the
        visual vocabulary, starting at ``visual_vocab_size - len(INDICATOR_IDS)``.
        All other ids are dropped.
        """
        hf_config = self.info.get_hf_config()
        # Shift that maps INDICATOR_IDS[0] to the first reserved visual slot.
        offset = hf_config.visual_vocab_size - len(INDICATOR_IDS) - INDICATOR_IDS[0]
        return [x + offset for x in visual_indicators if x >= INDICATOR_IDS[0]]

    def _make_indicator_tokens(
        self,
        hf_processor: Ovis2_5Processor,
        num_items: int,
        *,
        is_video: bool,
    ) -> torch.Tensor:
        """
        Return a tensor with one row of indicator tokens per item.

        Every item gets identical indicators, built from a ``(1, 1, 1)`` grid,
        so they are computed once and repeated.
        """
        indicator = self.visual_indicators_to_visual_tokens(
            hf_processor.construct_visual_indicators((1, 1, 1), is_video)
        )
        return torch.tensor([indicator] * num_items)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor and attach (video_)indicator_tokens per item.

        Text-only prompts are tokenized directly, without special tokens.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        hf_processor = self.info.get_hf_processor()

        # Videos are handled before images to keep the original key order.
        if "videos" in mm_data:
            processed_outputs["video_indicator_tokens"] = self._make_indicator_tokens(
                hf_processor, len(processed_outputs["video_grids"]), is_video=True
            )
        if "images" in mm_data:
            processed_outputs["indicator_tokens"] = self._make_indicator_tokens(
                hf_processor, len(processed_outputs["grids"]), is_video=False
            )
        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # Token-id prompts are passed through unchanged.
        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            grids=MultiModalFieldConfig.batched("image"),
            indicator_tokens=MultiModalFieldConfig.batched("image"),
            video_pixel_values=MultiModalFieldConfig.batched("video"),
            video_indicator_tokens=MultiModalFieldConfig.batched("video"),
            video_grids=MultiModalFieldConfig.batched("video"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptReplacement]:
        """
        Replace each image/video placeholder token with its expanded
        placeholder sequence, built from the item's grid.
        """
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        # Fetched once instead of once per replaced item.
        hf_processor = self.info.get_hf_processor()

        placeholder = {
            "image": vocab[IMAGE_TOKEN],
            "video": vocab[VIDEO_TOKEN],
        }
        # Per-item field that holds the grid for each modality.
        grid_keys = {"image": "grids", "video": "video_grids"}

        def get_replacement_ovis(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid = out_item[grid_keys[modality]].data
            return hf_processor.construct_visual_placeholders(grid[0])

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_ovis, modality=modality),
            )
            for modality in ("image", "video")
        ]


438
439
440
441
442
@MULTIMODAL_REGISTRY.register_processor(
    Ovis2_5MultiModalProcessor,
    info=Ovis2_5ProcessingInfo,
    dummy_inputs=Ovis2_5DummyInputsBuilder,
)
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
    """
    Ovis2.5 model.

    A visual tokenizer plus a visual embedding table (``vte``) turn images and
    videos into embeddings for a registered language model (``llm``).
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return IMAGE_TOKEN
        if modality.startswith("video"):
            return VIDEO_TOKEN

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config: PretrainedConfig = config

        with self._mark_language_model(vllm_config):
            self.llm = init_vllm_registered_model(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "llm"),
            )

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual_tokenizer = VisualTokenizer(
                config=config.vit_config,
                visual_vocab_size=config.visual_vocab_size,
                quant_config=quant_config,
                # Built like the `llm` prefix so an empty prefix does not
                # produce a leading "." in layer names.
                prefix=maybe_prefix(prefix, "visual_tokenizer"),
            )
            self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)

        self.image_pad_token_id: int = IMAGE_PAD_TOKEN_ID

        self.make_empty_intermediate_tensors = (
            self.get_language_model().make_empty_intermediate_tensors
        )

    def _parse_patch_fields(
        self,
        pixel_values: object,
        indicator_tokens: object,
        grids: object,
    ) -> dict[str, object] | None:
        """
        Validate raw image/video kwargs and build the common schema fields.

        Returns None when neither pixel values nor indicator tokens are given.

        Raises:
            ValueError: if only some of the inputs are present, or if they
                have an unexpected type.
        """
        if pixel_values is None and indicator_tokens is None:
            return None
        if pixel_values is None or indicator_tokens is None or grids is None:
            raise ValueError(
                "Pixel values, indicator tokens and grids must be provided "
                f"together. Got pixel_values={type(pixel_values)}, "
                f"indicator_tokens={type(indicator_tokens)}, grids={type(grids)}"
            )
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
            )
        if not isinstance(indicator_tokens, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of indicator_tokens. "
                f"Got type: {type(indicator_tokens)}"
            )

        # The ViT merges this many patches into one visual token.
        merge_size = self.config.vit_config.hidden_stride**2
        return dict(
            flat_data=flatten_bn(pixel_values, concat=True),
            patches_per_item=[x.shape[0] // merge_size for x in pixel_values],
            indicator_tokens=flatten_bn(indicator_tokens, concat=True),
            grids=flatten_bn(grids, concat=True),
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Ovis2_5ImagePatchInputs | None:
        fields = self._parse_patch_fields(
            kwargs.pop("pixel_values", None),
            kwargs.pop("indicator_tokens", None),
            kwargs.pop("grids", None),
        )
        if fields is None:
            return None
        return Ovis2_5ImagePatchInputs(type="image_patches", **fields)

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Ovis2_5VideoPatchInputs | None:
        fields = self._parse_patch_fields(
            kwargs.pop("video_pixel_values", None),
            kwargs.pop("video_indicator_tokens", None),
            kwargs.pop("video_grids", None),
        )
        if fields is None:
            return None
        return Ovis2_5VideoPatchInputs(type="video_patches", **fields)

    def _process_visual_input(
        self, visual_input: Ovis2_5ImagePatchInputs | Ovis2_5VideoPatchInputs
    ) -> MultiModalEmbeddings:
        """
        Embed a batch of images or videos.

        Returns one tensor per item, laid out as
        ``[first indicator, visual embeddings..., remaining indicators]``.
        """
        flat_patches = visual_input["flat_data"]
        tokens_per_item = visual_input["patches_per_item"]
        indicator_tokens = visual_input["indicator_tokens"]
        grid_thws = visual_input["grids"]

        num_items = len(tokens_per_item)
        if num_items == 0:
            return ()
        # Every item carries the same number of indicator tokens (the
        # multimodal processor builds them from one fixed grid), so derive the
        # per-item count rather than guessing it from the patch count.
        indicators_per_item, remainder = divmod(indicator_tokens.shape[0], num_items)
        if remainder:
            raise ValueError(
                f"Got {indicator_tokens.shape[0]} indicator tokens, which cannot "
                f"be split evenly across {num_items} items."
            )

        visual_tokens = self.visual_tokenizer(
            flat_patches.to(self.visual_tokenizer.dtype), grid_thws
        )
        visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
        indicator_embeds = self.vte(indicator_tokens)

        visual_chunks = visual_embeds.split(tokens_per_item, dim=0)
        indicator_chunks = indicator_embeds.split(indicators_per_item, dim=0)

        return tuple(
            torch.cat([indicators[:1], visual, indicators[1:]], dim=0)
            for indicators, visual in zip(indicator_chunks, visual_chunks)
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed by modality in kwargs order."""
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "indicator_tokens", "grids")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key
                in ("video_pixel_values", "video_indicator_tokens", "video_grids")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-item embeddings for all images and videos in kwargs."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for visual_input in modalities.values():
            # A modality whose keys were present but empty parses to None.
            if visual_input is None:
                continue
            multimodal_embeddings += tuple(self._process_visual_input(visual_input))

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; ``inputs_embeds`` is ignored on non-first PP ranks."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # up until here we have a inputs_embeds 100% numerical identity
        # between the OG HF Transformers implementation and ours
        hidden_states = self.llm(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.llm.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)