llava_onevision.py 32.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias
7
8
9

import torch
import torch.nn as nn
10
from transformers import BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor
11
from transformers.models.llava_onevision.modeling_llava_onevision import (
12
13
14
    get_anyres_image_grid_shape,
    unpad_image,
)
15

16
from vllm.config import VllmConfig
17
from vllm.config.multimodal import BaseDummyOptions
18
19
from vllm.model_executor.layers.activation import get_act_fn
from vllm.multimodal import MULTIMODAL_REGISTRY
20
21
22
23
24
25
26
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageSize,
    MultiModalDataItems,
    VideoEmbeddingItems,
    VideoProcessorItems,
)
31
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
32
from vllm.sequence import IntermediateTensors
33
from vllm.utils.tensor_schema import TensorSchema, TensorShape
34

35
from .clip import CLIPVisionModel
36
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
37
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
38
39
40
41
42
from .llava_next import (
    BaseLlavaNextMultiModalProcessor,
    LlavaNextLikeConfig,
    LlavaNextProcessingInfo,
)
43
from .siglip import SiglipVisionModel
44
45
46
47
48
49
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
50

51
52
53
# For profile run
_MAX_FRAMES_PER_VIDEO = 16

54

55
class LlavaOnevisionVideoPixelInputs(TensorSchema):
    """
    Raw pixel values for one or more videos.

    Dimensions:
        - bn: Batch size * number of videos
        - f: Number of frames
        - c: Number of channels (3)
        - h: Height
        - w: Width

        The frame count `f` may differ from video to video; in that case the
        data is passed as a list of per-video tensors instead of a single
        batched tensor.
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # `f` is declared dynamic so ragged (list) inputs validate.
    pixel_values_videos: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
    ]
75
76


77
class LlavaOnevisionImagePixelInputs(TensorSchema):
    """
    Raw pixel values for one or more images, split into anyres patches.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width

        The patch count may vary per batch and per image; in that case the
        data is passed as a list of tensors instead of a batched tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"

    # `np` is declared dynamic so ragged (list) inputs validate.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
    ]

    # Per-image size pair; may be absent.
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
98
99


100
class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings supplied directly by the caller.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
114
115


116
117
118
# Any accepted form of image input: raw pixels or precomputed embeddings.
LlavaOnevisionImageInputs: TypeAlias = (
    LlavaOnevisionImagePixelInputs | LlavaOnevisionImageEmbeddingInputs
)

# Any accepted multimodal input: an image input or video pixels.
LlavaOnevisionMultiInputs: TypeAlias = (
    LlavaOnevisionImageInputs | LlavaOnevisionVideoPixelInputs
)
123
124


125
126
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
    """Structural config type: a LLaVA-NeXT-like config plus the token id
    used as the video placeholder."""

    video_token_index: Final[int]
127
128


129
130
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
    """Config/processor access and token-count estimates for LLaVA-OneVision."""

    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
        return self.ctx.get_hf_config(LlavaOnevisionConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Both modalities report ``None`` (no explicit limit set here).
        return {"image": None, "video": None}

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    # with additional logic afterwards taken from LlavaOnevisionProcessor
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count grid tokens that survive unpadding of an anyres patch grid.

        Returns ``(unpadded_features, newline_features)``: the number of grid
        tokens left after removing the padding added by aspect-preserving
        resizing (shrunk further when the grid exceeds the ``anyres_max_9``
        budget by more than a 1.1 linear factor), and the number of grid
        rows, i.e. how many newline tokens are needed.
        """
        grid_h = npatches * num_patch_height
        grid_w = npatches * num_patch_width

        if original_width / original_height > grid_w / grid_h:
            # The image fills the grid width; padding lies along the height.
            scaled_h = int(round(original_height * (grid_w / original_width), 7))
            pad = (grid_h - scaled_h) // 2
            grid_h -= 2 * pad
        else:
            # The image fills the grid height; padding lies along the width.
            scaled_w = int(round(original_width * (grid_h / original_height), 7))
            pad = (grid_w - scaled_w) // 2
            grid_w -= 2 * pad

        ratio = math.sqrt(grid_h * grid_w / (9 * npatches**2))
        if ratio > 1.1:
            rows = int(grid_h // ratio)
            cols = int(grid_w // ratio)
        else:
            rows, cols = grid_h, grid_w

        return (rows * cols, rows)

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE: This hardcoded value is found via processor tests
        return ImageSize(width=1153, height=944)

    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Tokens per video frame after spatial pooling.

        ``image_width``/``image_height`` are accepted for interface symmetry
        but not used: the count depends only on the encoder's patch grid and
        the pooling stride.
        """
        stride = getattr(self.get_hf_config(), "spatial_pool_stride", 2)
        grid = self.get_vision_encoder_info().get_patch_grid_length()
        pooled_side = math.ceil(grid / stride)
        return pooled_side**2

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        frame_tokens = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        # +1 for the single newline embedding appended after all frames.
        return frame_tokens * num_frames + 1

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose video token count fits ``max_tokens``."""
        width, height = self.get_image_size_with_most_features()

        frames = 0
        while (
            self.get_num_video_tokens(
                image_width=width,
                image_height=height,
                num_frames=frames + 1,
            )
            <= max_tokens
        ):
            frames += 1
        return frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video for profiling: share the frame budget that fits
        ``seq_len`` across videos, capped at ``_MAX_FRAMES_PER_VIDEO`` and
        never below one."""
        video_slots = max(mm_counts.get("video", 0), 1)
        per_video_budget = self._get_max_video_frames(seq_len) // video_slots
        return max(min(per_video_budget, _MAX_FRAMES_PER_VIDEO), 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        width, height = self.get_image_size_with_most_features()
        frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        return self.get_num_video_tokens(
            image_width=width,
            image_height=height,
            num_frames=frames,
        )

261
262

class LlavaOnevisionDummyInputsBuilder(
    LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]
):
    """Builds dummy prompts and media for LLaVA-OneVision profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One image placeholder per image, then one video placeholder per
        video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        return "".join(
            (
                hf_processor.image_token * image_count,
                hf_processor.video_token * video_count,
            )
        )

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Dummy images and videos at the feature-maximising frame size.

        Videos use the frame count from
        ``get_num_frames_with_most_features``; per-modality overrides come
        from ``mm_options`` when given.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        options = mm_options or {}

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=options.get("image"),
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=num_frames,
            num_videos=video_count,
            overrides=options.get("video"),
        )
        return {"image": dummy_images, "video": dummy_videos}


309
class LlavaOnevisionMultiModalProcessor(
    BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]
):
    """Multimodal processor adding per-video handling on top of LLaVA-NeXT."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every field is batched along its modality's item dimension.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_sizes": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": MultiModalFieldConfig.batched("video"),
        }

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, splitting the work when videos are present.

        Without videos everything goes through one call. With videos, the
        text, the images (together) and each video are processed by separate
        calls, and the relevant outputs are merged into one ``BatchFeature``.
        """
        remaining = dict(mm_data)
        videos = remaining.pop("videos", [])
        assert isinstance(videos, list)

        if not videos:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=remaining,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        # LLaVA-OneVision processor doesn't support multiple videos
        # with different sizes when converting back to tensors
        # So, we process each component separately
        # NOTE: No prompt replacement is applied in this case
        hf_processor = self.info.get_hf_processor()
        image_token = hf_processor.image_token
        video_token = hf_processor.video_token

        combined_outputs = dict(
            super()._call_hf_processor(
                prompt=prompt,
                mm_data={},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
        )

        images = remaining.pop("images", [])
        assert isinstance(images, list)
        if images:
            image_batch = super()._call_hf_processor(
                prompt=image_token * len(images),
                mm_data={"images": images},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
            combined_outputs.update(
                (key, value)
                for key, value in image_batch.items()
                if key in ("pixel_values", "image_sizes")
            )

        # A plain loop (not a comprehension) keeps zero-arg super() valid.
        per_video_pixels = []
        for video in videos:
            video_batch = super()._call_hf_processor(
                prompt=video_token,
                mm_data={"videos": video},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
            per_video_pixels.append(video_batch["pixel_values_videos"][0])
        combined_outputs["pixel_values_videos"] = per_video_pixels

        return BatchFeature(combined_outputs)

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # The split video path above applies no prompt replacement, so any
        # video disables HF-side updates.
        applies = super()._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )
        if not applies:
            return applies
        return mm_items.get_count("video", strict=False) == 0

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        video_token_id = self.info.get_hf_config().video_token_index

        def get_video_replacement(item_idx: int):
            # Expand one video placeholder into its full token run.
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems)
            )
            if isinstance(videos, VideoEmbeddingItems):
                return [video_token_id] * videos.get_feature_size(item_idx)

            frame_size = videos.get_frame_size(item_idx)
            token_count = self.info.get_num_video_tokens(
                image_width=frame_size.width,
                image_height=frame_size.height,
                num_frames=videos.get_num_frames(item_idx),
            )
            return [video_token_id] * token_count

        video_update = PromptReplacement(
            modality="video",
            target=[video_token_id],
            replacement=get_video_replacement,
        )
        return [*image_updates, video_update]

452
453
454
455
456

class LlavaOnevisionMultiModalProjector(nn.Module):
    """Two-layer MLP projecting vision features to the text hidden size."""

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()
        vision_dim = config.vision_config.hidden_size
        text_dim = config.text_config.hidden_size
        use_bias = config.multimodal_projector_bias

        # Registration order (linear_1, act, linear_2) is kept unchanged.
        self.linear_1 = nn.Linear(vision_dim, text_dim, bias=use_bias)
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(text_dim, text_dim, bias=use_bias)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        return self.linear_2(self.act(self.linear_1(image_features)))


476
477
478
@MULTIMODAL_REGISTRY.register_processor(
    LlavaOnevisionMultiModalProcessor,
    info=LlavaOnevisionProcessingInfo,
479
480
481
    dummy_inputs=LlavaOnevisionDummyInputsBuilder,
)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
482
483
484
485
486
487
488
489
    # Rewrites parameter-name prefixes used by checkpoints saved after
    # transformers v4.52 onto this module's attribute layout.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        },
    )
492

493
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Placeholder text for an image or video modality (``i`` unused)."""
        for prefix, placeholder in (("image", "<image>"), ("video", "<video>")):
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

502
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # Vision-side modules are built under _mark_tower_model for both the
        # image and the video modality.
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            # Initialize the vision tower only up to the required feature layer
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            # Learned separator vector, sized to the text hidden dimension.
            self.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size)
            )
            self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Delegate pipeline-parallel buffer creation to the language model.
        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors
        )
534

535
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaOnevisionImageInputs | None:
        """Build a validated image input from kwargs.

        Pixel values take precedence over embeddings; returns ``None`` when
        neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            side = self.config.vision_config.image_size
            return LlavaOnevisionImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                resolve_bindings={"h": side, "w": side},
            )

        if image_embeds is not None:
            return LlavaOnevisionImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> LlavaOnevisionVideoPixelInputs | None:
        """Build a validated video input from ``pixel_values_videos``.

        The value may be a batched tensor of shape
        ``(num_videos, num_frames, channels, height, width)`` or a list of
        per-video ``(num_frames, channels, height, width)`` tensors. Returns
        ``None`` when the key is absent.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        if pixel_values_videos is None:
            return None

        side = self.config.vision_config.image_size
        return LlavaOnevisionVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=pixel_values_videos,
            resolve_bindings={"h": side, "w": side},
        )
586
587

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse each modality once, keyed in order of first appearance in
        ``kwargs``."""
        key_to_modality = {
            "pixel_values": "image",
            "image_embeds": "image",
            "pixel_values_videos": "video",
            "video_embeds": "video",
        }
        parsers = {
            "image": self._parse_and_validate_image_input,
            "video": self._parse_and_validate_video_input,
        }

        parsed: dict = {}
        for input_key in kwargs:
            modality = key_to_modality.get(input_key)
            if modality is None or modality in parsed:
                continue
            # Each parser receives its own copy of kwargs (via **), so its
            # pops do not disturb this iteration.
            parsed[modality] = parsers[modality](**kwargs)

        return parsed
609
610
611

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower with the configured feature-select strategy."""
        # The feature-layer selection already happens inside the vision
        # tower, so no extra layer slicing is done here.
        select_strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=select_strategy)

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(
        self,
        image_size: torch.Tensor,
        patch_embeddings: torch.Tensor,
        *,
        image_newline=None,
        vision_aspect_ratio="anyres_max_9",
        strategy: str,
    ) -> torch.Tensor:
        """Merge one image's per-patch embeddings into a flat token sequence.

        Args:
            image_size: Original ``(height, width)`` of the image.
            patch_embeddings: Projected patch embeddings; assumed to be
                ``(num_patches, height * width, hidden)``. Entry 0 is the base
                patch, the rest are anyres grid crops (possibly padded with
                extra entries for batching).
            image_newline: Optional embedding appended at the end of every row
                of the unpadded grid, or after the base patch when there are
                no grid crops. Nothing is appended when ``None``.
            vision_aspect_ratio: ``"anyres_max_<N>"``. The unpadded grid is
                bilinearly downscaled when its area exceeds ``N`` base-patch
                areas by more than a 1.1 linear factor.
            strategy: ``"flat"`` or a ``"spatial*"`` strategy. Containing
                ``"unpad"`` enables unpadding, downscaling and newlines.

        Returns:
            The merged ``(num_tokens, hidden)`` embeddings.

        Raises:
            ValueError: If the base patch size disagrees with the vision
                config, or ``strategy`` is unknown.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if not strategy.startswith("spatial"):
            raise ValueError(f"Unexpected patch merge strategy: {strategy}")

        height = width = (
            self.config.vision_config.image_size
            // self.config.vision_config.patch_size
        )

        base_patch_embeds = patch_embeddings[0]
        if height * width != base_patch_embeds.shape[0]:
            raise ValueError(
                "The number of patches is not consistent with the image size."
            )

        use_unpad = "unpad" in strategy

        if patch_embeddings.shape[0] <= 1:
            # Only the base patch. Use the caller-supplied ``image_newline``
            # (as the multi-patch path does) instead of always reading
            # ``self.image_newline``, so ``image_newline=None`` is honoured.
            if use_unpad and image_newline is not None:
                return torch.cat(
                    (
                        base_patch_embeds,
                        image_newline[None].to(base_patch_embeds.device),
                    ),
                    dim=0,
                )
            return base_patch_embeds

        # Move to CPU to avoid floating-point errors
        orig_height, orig_width = image_size.tolist()

        # image_aspect_ratio == "anyres"
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            (orig_height, orig_width),
            self.config.image_grid_pinpoints,
            self.config.vision_config.image_size,
        )
        num_patches = num_patch_height * num_patch_width

        # Image patches might be padded for batch processing; keep only the
        # crops belonging to this image's grid.
        other_patch_embeds = patch_embeddings[1:][:num_patches].view(
            num_patch_height, num_patch_width, height, width, -1
        )

        if use_unpad:
            # (rows, cols, h, w, C) -> (C, rows * h, cols * w)
            other_patch_embeds = (
                other_patch_embeds.permute(4, 0, 2, 1, 3)
                .contiguous()
                .flatten(1, 2)
                .flatten(2, 3)
            )
            other_patch_embeds = unpad_image(
                other_patch_embeds, (orig_height, orig_width)
            )

            max_num_patches = int(vision_aspect_ratio.removeprefix("anyres_max_"))
            _, curr_height, curr_width = other_patch_embeds.shape
            ratio = math.sqrt(
                curr_height * curr_width / (max_num_patches * height**2)
            )
            if ratio > 1.1:
                other_patch_embeds = nn.functional.interpolate(
                    other_patch_embeds[None],
                    [int(curr_height // ratio), int(curr_width // ratio)],
                    mode="bilinear",
                )[0]

            # Append the newline embedding as an extra column (end of row).
            if image_newline is not None:
                other_patch_embeds = torch.cat(
                    (
                        other_patch_embeds,
                        image_newline[:, None, None]
                        .expand(*other_patch_embeds.shape[:-1], 1)
                        .to(other_patch_embeds.device),
                    ),
                    dim=-1,
                )
            # (C, H, W) -> (H * W, C)
            other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(0, 1)
        else:
            # (rows, cols, h, w, C) -> (rows * h * cols * w, C)
            other_patch_embeds = (
                other_patch_embeds.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 3)
            )

        return torch.cat((base_patch_embeds, other_patch_embeds), dim=0)

    def _process_image_pixels(
        self,
        inputs: LlavaOnevisionImagePixelInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Encode and project image patches.

        A batched tensor input gives a batched output. A list input (ragged
        patch counts) is encoded in one pass and split back per image.
        """
        pixel_values = inputs["pixel_values"]

        if not isinstance(pixel_values, torch.Tensor):
            patch_counts = [image.shape[0] for image in pixel_values]
            features = self._image_pixels_to_features(
                self.vision_tower, torch.cat(pixel_values)
            )
            return [
                self.multi_modal_projector(chunk)
                for chunk in torch.split(features, patch_counts)
            ]

        batch, num_patches, channels, height, width = pixel_values.shape
        flat_pixels = pixel_values.view(batch * num_patches, channels, height, width)
        features = self._image_pixels_to_features(self.vision_tower, flat_pixels)
        projected = self.multi_modal_projector(features)
        return projected.view(batch, num_patches, *projected.shape[1:])

    def _process_image_input(
        self,
        image_input: LlavaOnevisionImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Turn a parsed image input into per-image embedding sequences.

        Precomputed embeddings pass through unchanged. Otherwise, when no
        sizes were supplied, every image is treated as having the vision
        config's square ``image_size``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            num_images = len(image_input["pixel_values"])
            side = self.config.vision_config.image_size
            image_sizes = torch.as_tensor([[side, side]] * num_images)

        merged = []
        for idx, patch_features in enumerate(patch_embeddings):
            merged.append(
                self._merge_image_patch_embeddings(
                    image_sizes[idx],
                    patch_features,
                    image_newline=self.image_newline,
                    strategy="spatial_unpad",
                )
            )
        return merged

    def _video_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode, project and spatially pool a stack of video frames.

        ``pixel_values`` holds frames stacked along dim 0. The pooling stride
        comes from ``config.spatial_pool_stride`` (default 2). This is the same
        value ``LlavaOnevisionProcessingInfo._get_num_frame_tokens`` uses to
        size the per-frame placeholders, so the number of embeddings produced
        matches the number of placeholder tokens.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        video_features = vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )
        video_features = self.multi_modal_projector(video_features)

        pool_stride = getattr(self.config, "spatial_pool_stride", 2)
        return self.apply_pooling(video_features, stride=pool_stride)

    def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
        """Embed each video as all of its frames' tokens followed by one
        ``image_newline`` vector.

        Batched tensor input yields a single batched tensor; list input
        (ragged frame counts) yields one ``(1, tokens, hidden)`` tensor per
        video.
        """
        video_pixels = inputs["pixel_values_videos"]

        if not isinstance(video_pixels, torch.Tensor):
            frame_counts = [len(video) for video in video_pixels]
            features = self._video_pixels_to_features(
                self.vision_tower, torch.cat(video_pixels)
            )
            tokens_per_frame = features.shape[1]
            newline = self.image_newline[None, None, :]

            outputs = []
            for frames, per_video in zip(
                frame_counts, torch.split(features, frame_counts)
            ):
                flat = per_video.reshape(1, frames * tokens_per_frame, -1)
                outputs.append(torch.cat((flat, newline), dim=1))
            return outputs

        num_videos, frames, channels, height, width = video_pixels.shape
        features = self._video_pixels_to_features(
            self.vision_tower,
            video_pixels.view(num_videos * frames, channels, height, width),
        )
        features = features.reshape(num_videos, frames * features.shape[1], -1)
        newline = self.image_newline[None, None, :].expand(num_videos, -1, -1)
        return torch.cat((features, newline), dim=1)
847

848
    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
        """Spatially downsample per-frame patch features by ``stride``.

        The token axis is viewed as a square ``(g, g)`` grid, where
        ``g = image_size // patch_size`` from ``self.config.vision_config``.
        The grid is bilinearly resized to
        ``(ceil(g / stride), ceil(g / stride))`` and flattened back.

        Args:
            image_features: Expected to be ``(batch_frames, g * g, dim)``.
            stride: Downsampling factor applied to each spatial axis.

        Returns:
            Tensor of shape ``(batch_frames, ceil(g / stride) ** 2, dim)``.

        Raises:
            ValueError: If the token count is not ``g * g``. Without this
                check, a multiple of ``g * g`` would be silently reinterpreted
                as extra channels.
        """
        vision_config = self.config.vision_config
        grid = vision_config.image_size // vision_config.patch_size
        batch_frames, num_tokens, dim = image_features.shape
        if num_tokens != grid * grid:
            raise ValueError(
                f"Expected {grid * grid} tokens ({grid}x{grid} patch grid), "
                f"got {num_tokens}"
            )

        # (B, g*g, D) -> (B, D, g, g): interpolate expects channels first.
        image_features = image_features.view(batch_frames, grid, grid, dim)
        image_features = image_features.permute(0, 3, 1, 2)

        # TODO support other pooling types config
        height, width = image_features.shape[2:]
        scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
        pooled = nn.functional.interpolate(
            image_features, size=scaled_shape, mode="bilinear"
        )

        # (B, D, h', w') -> (B, h' * w', D)
        pooled = pooled.permute(0, 2, 3, 1)
        return pooled.view(batch_frames, -1, dim)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for every image and video item in ``kwargs``.

        Returns:
            An empty list when no multimodal inputs are parsed. Otherwise, a
            tuple with one tensor per multimodal item (image or video), in
            the iteration order of the parsed modality mapping. Modalities
            other than ``"image"`` and ``"video"`` are skipped.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                embeddings = self._process_image_input(multimodal_input)
            elif modality == "video":
                embeddings = self._process_video_pixels(multimodal_input)
            else:
                continue
            multimodal_embeddings += tuple(embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the LLaVA-OneVision language model over a flattened batch.

        This method does not process any pixel inputs itself. Multimodal
        features are presumably already merged into ``inputs_embeds`` by the
        caller.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch; may be ``None``.
            positions: Position ids passed through to the language model.
            intermediate_tensors: Optional intermediate tensors (presumably
                from a previous pipeline-parallel stage). When given,
                ``inputs_embeds`` is discarded.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The output of ``self.language_model.model`` unchanged.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
914
    ) -> torch.Tensor | None:
915
        return self.language_model.compute_logits(hidden_states)
916

917
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model via ``AutoWeightsLoader``.

        Weight names are remapped with ``self.hf_to_vllm_mapper`` before
        loading (presumably translating Hugging Face checkpoint names to this
        model's module paths).

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``
            (presumably the names of the parameters that were loaded).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)