# File: llava_onevision.py (33 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Final, Literal, Optional, Protocol, Union

import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor
from transformers.models.llava_onevision.modeling_llava_onevision import (
    get_anyres_image_grid_shape,
    unpad_image,
)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageSize,
    MultiModalDataItems,
    VideoEmbeddingItems,
    VideoProcessorItems,
)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (
    BaseLlavaNextMultiModalProcessor,
    LlavaNextLikeConfig,
    LlavaNextProcessingInfo,
)
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
51

52
53
54
# For profile run
_MAX_FRAMES_PER_VIDEO = 16

55

56
class LlavaOnevisionVideoPixelInputs(TensorSchema):
    """
    Raw pixel values for one or more videos.

    Dimensions:
        - bn: Batch size * number of videos
        - f: Number of frames
        - c: Number of channels (3)
        - h: Height
        - w: Width

        Note that `num_videos` may be different for each batch, and `num_frames`
        may be different for each video, in which case the data is passed as a
        list instead of a batched tensor.
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # The frame axis is declared dynamic so per-video frame counts may differ.
    pixel_values_videos: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
    ]
76
77


78
class LlavaOnevisionImagePixelInputs(TensorSchema):
    """
    Raw pixel values for one or more images, each split into anyres tiles.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches (1 + num_patches)
        - c: Number of channels (3)
        - h: Height
        - w: Width

        Note that `num_patches` may be different per batch and image,
        in which case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"

    # The tile axis is declared dynamic so per-image tile counts may differ.
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
    ]

    # Original (height, width) of each image, one row per image.
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
99
100


101
class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    data: Annotated[
        torch.Tensor,
        TensorShape("bn", "ifs", "hs"),
    ]
115
116


117
118
119
# An image input is either raw tiled pixels or precomputed embeddings.
LlavaOnevisionImageInputs = Union[
    LlavaOnevisionImagePixelInputs,
    LlavaOnevisionImageEmbeddingInputs,
]

# Union of every image and video input schema defined above.
LlavaOnevisionMultiInputs = Union[
    LlavaOnevisionImageInputs,
    LlavaOnevisionVideoPixelInputs,
]
124
125


126
127
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
    """Structural config type: LLaVA-NeXT fields plus the video token id."""

    # Token id that marks a video placeholder in the prompt.
    video_token_index: Final[int]
128
129


130
131
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
    """LLaVA-NeXT processing info extended with video token accounting."""

    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
        return self.ctx.get_hf_config(LlavaOnevisionConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed per-prompt item count is set for either modality.
        return {"image": None, "video": None}

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    # with additional logic afterwards taken from LlavaOnevisionProcessor
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count the anyres grid features left after unpadding.

        The ``npatches * num_patch_height`` by ``npatches * num_patch_width``
        feature grid is trimmed symmetrically so that it matches the aspect
        ratio of the original image. If the trimmed grid is still larger than
        about 9 base tiles (ratio > 1.1), both sides are shrunk by that ratio.

        Returns:
            ``(unpadded_features, newline_features)``; the latter equals the
            number of grid rows (one newline token per row).
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if aspect_ratio > current_aspect_ratio:
            # Image is wider than the grid: padding was added top/bottom.
            new_height = int(
                round(original_height * (current_width / original_width), 7)
            )
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            # Image is taller (or equal): padding was added left/right.
            new_width = int(
                round(original_width * (current_height / original_height), 7)
            )
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height

        ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
        if ratio > 1.1:
            height_factor = int(current_height // ratio)
            width_factor = int(current_width // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size that yields the most image features."""
        # NOTE: This hardcoded value is found via processor tests
        return ImageSize(width=1153, height=944)

    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of tokens one video frame contributes.

        This is the vision encoder's patch grid side, pooled by
        ``spatial_pool_stride`` (2 if the config lacks it) and rounded up,
        then squared. The frame size arguments are not used by this
        computation.
        """
        hf_config = self.get_hf_config()
        spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

        vision_encoder_info = self.get_vision_encoder_info()
        patch_grid_length = vision_encoder_info.get_patch_grid_length()
        pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

        return pooled_grid_length * pooled_grid_length

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        """Return the total token count for a video of ``num_frames`` frames.

        The count is every frame's pooled tokens plus one trailing newline
        token for the whole video.
        """
        num_frame_tokens = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )

        return num_frame_tokens * num_frames + 1  # Newline token

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the most frames whose video token count fits in ``max_tokens``.

        The token count is evaluated at the max-feature image size. The
        result is 0 when even a single frame does not fit.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the frames per video to use for dummy inputs.

        The frame budget that fits in ``seq_len`` is split evenly across the
        requested videos. The result is capped at ``_MAX_FRAMES_PER_VIDEO``
        and is always at least 1.
        """
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of the largest dummy video for ``seq_len``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
        )

262
263

class LlavaOnevisionDummyInputsBuilder(
    LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]
):
    """Builds dummy text and media sized for the maximum feature count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one processor placeholder token per requested image/video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the max-feature resolution.

        Videos use the frame count from
        ``get_num_frames_with_most_features``. Per-modality overrides are
        taken from ``mm_options`` when it is given.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }


310
class LlavaOnevisionMultiModalProcessor(
    BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]
):
    """LLaVA-NeXT image processor extended with per-video processing."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every output field as batched per image or per video item."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.batched("video"),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling videos one at a time.

        Without videos this delegates to the base implementation. With
        videos, the prompt text, the images and each video are processed in
        separate calls and the outputs are merged. In that case no prompt
        replacement is applied to the text.
        """
        mm_data = dict(mm_data)
        videos = mm_data.pop("videos", [])
        assert isinstance(videos, list)

        if not videos:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        # LLaVA-OneVision processor doesn't support multiple videos
        # with different sizes when converting back to tensors
        # So, we process each component separately
        # NOTE: No prompt replacement is applied in this case
        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token

        text_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data={},
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        images = mm_data.pop("images", [])
        assert isinstance(images, list)
        if images:
            processor_outputs = super()._call_hf_processor(
                prompt=image_token * len(images),
                mm_data={"images": images},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )
            # Keep only the image tensors; the token ids come from text_outputs.
            image_outputs = {
                k: v
                for k, v in processor_outputs.items()
                if k in ("pixel_values", "image_sizes")
            }
        else:
            image_outputs = {}

        pixel_values_videos = []
        for video in videos:
            item_outputs = super()._call_hf_processor(
                prompt=video_token,
                mm_data={"videos": video},
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

            # One video per call, so take the single entry of the batch.
            pixel_values_videos.append(item_outputs["pixel_values_videos"][0])

        video_outputs = {"pixel_values_videos": pixel_values_videos}

        combined_outputs = dict(
            text_outputs,
            **image_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Report HF-side prompt updates only when no video item is present.

        With videos, ``_call_hf_processor`` processes the text without any
        multimodal data, so no placeholder expansion happens there.
        """
        base_result = super()._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return base_result and mm_items.get_count("video", strict=False) == 0

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the base image updates plus a video-token expansion.

        Each ``video_token_index`` in the prompt is replaced by as many copies
        as the video's feature size.
        """
        image_repls = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        hf_config = self.info.get_hf_config()
        video_token_id = hf_config.video_token_index

        def get_video_replacement(item_idx: int):
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems)
            )

            if isinstance(videos, VideoEmbeddingItems):
                num_video_tokens = videos.get_feature_size(item_idx)
            else:
                image_size = videos.get_frame_size(item_idx)
                num_video_tokens = self.info.get_num_video_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    num_frames=videos.get_num_frames(item_idx),
                )

            return [video_token_id] * num_video_tokens

        return [
            *image_repls,
            PromptReplacement(
                modality="video",
                target=[video_token_id],
                replacement=get_video_replacement,
            ),
        ]

453
454
455
456
457

class LlavaOnevisionMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden size."""

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            bias=config.multimodal_projector_bias,
        )
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            bias=config.multimodal_projector_bias,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` (last dim = vision hidden size)."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


477
478
479
@MULTIMODAL_REGISTRY.register_processor(
    LlavaOnevisionMultiModalProcessor,
    info=LlavaOnevisionProcessingInfo,
    dummy_inputs=LlavaOnevisionDummyInputsBuilder,
)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-OneVision: a CLIP/SigLIP vision tower and MLP projector feeding
    a vLLM language model, with image and video inputs."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for ``modality``; ``i`` is unused.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor "video".
        """
        if modality.startswith("image"):
            return "<image>"
        if modality.startswith("video"):
            return "<video>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector, language model and newline param.

        Args:
            vllm_config: Engine config; the HF config, quantization config and
                multimodal config are read from it.
            prefix: Module-name prefix for the vision tower and language model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Left uninitialized (torch.empty); expected to be filled from weights.
        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))

        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[LlavaOnevisionImageInputs]:
        """Build a validated image input from ``kwargs``.

        Returns:
            Pixel inputs when ``pixel_values`` is given (``image_sizes`` is
            then required). Otherwise embedding inputs when ``image_embeds``
            is given. ``None`` when neither is present.

        Raises:
            ValueError: If a provided field has an unexpected type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
                )

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError(
                    f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
                )

            return LlavaOnevisionImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={
                    "h": self.config.vision_config.image_size,
                    "w": self.config.vision_config.image_size,
                },
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError(
                    f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
                )

            return LlavaOnevisionImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Optional[LlavaOnevisionVideoPixelInputs]:
        """
        A legal video input should have the following dimensions:
        {
            "pixel_values_videos" :
                list[b, Tensor(nb_frames, nb_channels, height, width)]
        }

        Only ``pixel_values_videos`` is read; ``None`` is returned when it is
        absent.

        Raises:
            ValueError: If ``pixel_values_videos`` is neither a tensor nor a
                list.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        if pixel_values_videos is None:
            return None

        if not isinstance(pixel_values_videos, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of pixel_values_videos. "
                f"Got type: {type(pixel_values_videos)}"
            )

        return LlavaOnevisionVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=flatten_bn(pixel_values_videos),
            resolve_bindings={
                "h": self.config.vision_config.image_size,
                "w": self.config.vision_config.image_size,
            },
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and/or video inputs out of ``kwargs``.

        Returns:
            A dict keyed by ``"image"`` / ``"video"``. Its insertion order
            follows the first key of each modality seen in ``kwargs``. A value
            can be ``None`` when the modality's parser finds nothing it
            handles; for example, when only ``video_embeds`` is passed, the
            video parser does not read it.
        """
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )

        return mm_input_by_modality

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower with the configured feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(
        self,
        image_size: torch.Tensor,
        patch_embeddings: torch.Tensor,
        *,
        image_newline=None,
        vision_aspect_ratio="anyres_max_9",
        strategy: str,
    ) -> torch.Tensor:
        """Merge the base-tile and anyres sub-tile embeddings of one image.

        Args:
            image_size: Original ``(height, width)`` of the image.
            patch_embeddings: Projected embeddings, one entry per tile. The
                base (resized) tile is at index 0 and the anyres sub-tiles
                follow. Sub-tiles beyond the grid implied by ``image_size``
                are treated as batch padding and dropped.
            image_newline: Optional embedding appended after each row of the
                unpadded grid (or after the base tile when there are no
                sub-tiles). No newline is added when it is ``None``.
            vision_aspect_ratio: ``"anyres_max_<N>"``. The unpadded grid is
                bilinearly downscaled when it is larger than about ``N`` base
                tiles.
            strategy: ``"flat"`` or a ``"spatial*"`` strategy. If the name
                contains ``"unpad"``, padding is removed and newlines are
                inserted.

        Raises:
            ValueError: If the base tile does not hold
                ``(image_size // patch_size) ** 2`` tokens, or if ``strategy``
                is not recognised.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            )

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the image size."
                )

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1
                )

                if "unpad" in strategy:
                    # (rows, cols, h, w, C) -> (C, rows*h, cols*w)
                    other_patch_embeds = (
                        other_patch_embeds.permute(4, 0, 2, 1, 3)
                        .contiguous()
                        .flatten(1, 2)
                        .flatten(2, 3)
                    )
                    other_patch_embeds = unpad_image(
                        other_patch_embeds, (orig_height, orig_width)
                    )
                    max_num_patches = int(
                        vision_aspect_ratio.removeprefix("anyres_max_")
                    )
                    _, curr_height, curr_width = other_patch_embeds.shape
                    ratio = math.sqrt(
                        curr_height * curr_width / (max_num_patches * height**2)
                    )
                    if ratio > 1.1:
                        other_patch_embeds = other_patch_embeds[None]
                        other_patch_embeds = nn.functional.interpolate(
                            other_patch_embeds,
                            [int(curr_height // ratio), int(curr_width // ratio)],
                            mode="bilinear",
                        )[0]
                    if image_newline is not None:
                        # Append one newline column to every grid row.
                        other_patch_embeds = torch.cat(
                            (
                                other_patch_embeds,
                                image_newline[:, None, None]
                                .expand(*other_patch_embeds.shape[:-1], 1)
                                .to(other_patch_embeds.device),
                            ),
                            dim=-1,
                        )
                    other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
                        0, 1
                    )
                else:
                    other_patch_embeds = (
                        other_patch_embeds.permute(0, 2, 1, 3, 4)
                        .contiguous()
                        .flatten(0, 3)
                    )

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0
                )
            else:
                merged_patch_embeddings = base_patch_embeds
                # Use the ``image_newline`` argument (not ``self.image_newline``)
                # so that ``None`` disables the newline here too, consistent
                # with the multi-tile branch above.
                if "unpad" in strategy and image_newline is not None:
                    merged_patch_embeddings = torch.cat(
                        (
                            base_patch_embeds,
                            image_newline[None].to(base_patch_embeds.device),
                        ),
                        dim=0,
                    )

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaOnevisionImagePixelInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Encode and project every tile of every image.

        A batched ``(b, num_patches, c, h, w)`` tensor goes through the
        vision tower in one pass and comes back as
        ``(b, num_patches, *feature_shape)``. A list of per-image tensors with
        varying tile counts is concatenated for one vision-tower pass. The
        result is then split back into a list of per-image projected
        embeddings.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values
            )
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features
            )

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:]
            )

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values
        )

        return [
            self.multi_modal_projector(image_features)
            for image_features in torch.split(
                stacked_image_features, num_patches_per_batch
            )
        ]

    def _process_image_input(
        self,
        image_input: LlavaOnevisionImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn an image input into language-model embeddings.

        Precomputed embeddings are returned as-is, wrapped in a one-element
        list. Pixel inputs are encoded and their tiles merged per image with
        the ``"spatial_unpad"`` strategy. When ``image_sizes`` is missing,
        every image is assumed to be ``image_size`` x ``image_size``.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["pixel_values"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor(
                [[default_height, default_width] for _ in range(batch_size)]
            )

        return [
            self._merge_image_patch_embeddings(
                image_sizes[i],
                patch_features_batch,
                image_newline=self.image_newline,
                strategy="spatial_unpad",
            )
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode frames, project them and pool them with ``apply_pooling``."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        video_features = vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )
        video_features = self.multi_modal_projector(video_features)
        video_features = self.apply_pooling(video_features)
        return video_features

    def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
        assert self.vision_tower is not None

827
        video_pixels = inputs["pixel_values_videos"]
828
829

        if isinstance(video_pixels, torch.Tensor):
830
            total_videos, frames, c, h, w = video_pixels.shape
831
            video_pixels_flat = video_pixels.view(total_videos * frames, c, h, w)
832
833

            embeddings_flat = self._video_pixels_to_features(
834
835
                self.vision_tower, video_pixels_flat
            )
836
837

            embeddings_flat = embeddings_flat.reshape(
838
839
                total_videos, frames * embeddings_flat.shape[1], -1
            )
840
841

            image_newline = self.image_newline[None, None, :].expand(
842
843
                total_videos, -1, -1
            )
844
845
846
847
848
849
            return torch.cat((embeddings_flat, image_newline), dim=1)

        frames_per_video = [len(video) for video in video_pixels]
        video_pixels_flat = torch.cat(video_pixels)

        embeddings_flat = self._video_pixels_to_features(
850
851
            self.vision_tower, video_pixels_flat
        )
852
853
854
855
856
857

        image_newline = self.image_newline[None, None, :]

        return [
            torch.cat(
                (
858
                    embeds.reshape(1, num_frame * embeddings_flat.shape[1], -1),
859
860
861
                    image_newline,
                ),
                dim=1,
862
863
            )
            for num_frame, embeds in zip(
864
865
866
867
                frames_per_video,
                torch.split(embeddings_flat, frames_per_video),
            )
        ]
868

869
    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
        """Spatially downsample per-frame tokens by ``stride`` with bilinear resizing.

        ``image_features`` is unpacked as ``(batch_frames, tokens, dim)``.
        The tokens are viewed as a square grid whose side is
        ``image_size // patch_size``, so ``tokens`` is expected to equal that
        side squared. The grid is resized to ``ceil(side / stride)`` per
        axis and flattened back to ``(batch_frames, new_tokens, dim)``.
        """
        vcfg = self.config.vision_config
        side = vcfg.image_size // vcfg.patch_size
        num_frames, _, channels = image_features.shape

        # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W), the layout interpolate expects.
        grid = image_features.view(num_frames, side, side, -1).permute(0, 3, 1, 2)

        # TODO support other pooling types config
        target = [math.ceil(extent / stride) for extent in grid.shape[2:]]
        pooled = nn.functional.interpolate(grid, size=target, mode="bilinear")

        # (B, C, H', W') -> (B, H', W', C) -> (B, H'*W', C)
        return pooled.permute(0, 2, 3, 1).view(num_frames, -1, channels)

886
887
888
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text decoder stored as ``self.language_model``."""
        language_model = self.language_model
        return language_model

889
890
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode every image and video item found in ``kwargs``.

        Returns:
            An empty list when ``_parse_and_validate_multimodal_inputs`` finds
            no multimodal inputs. Otherwise, a tuple with one embedding tensor
            per item, grouped by modality in the order the parsed dict yields
            them. Modalities other than ``"image"`` and ``"video"`` are skipped.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        # One tensor per multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: Iterate the dict in its own order to preserve the order of
        # the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += tuple(vision_embeddings)
            elif modality == "video":
                video_embeddings = self._process_video_pixels(multimodal_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

912
913
914
915
916
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the LLaVA-OneVision language backbone.

        Args:
            input_ids: Flattened (concatenated) token ids for the batch.
            positions: Position ids, forwarded unchanged to the decoder.
            intermediate_tensors: Forwarded to the decoder. When it is given,
                ``inputs_embeds`` is dropped. NOTE(review): presumably this
                happens on a later pipeline-parallel stage; confirm.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted for interface compatibility and not used here
                (e.g. ``pixel_values_videos``).

        Returns:
            Whatever ``self.language_model.model`` returns.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        lm = self.language_model
        return lm.compute_logits(hidden_states)

941
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Checkpoint names are remapped with ``self.hf_to_vllm_mapper`` by
        ``AutoWeightsLoader``. The return value is whatever the loader
        reports (a set of names, per the signature).
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )