llava_onevision.py 36.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Final, Literal, Optional, Protocol, TypedDict, Union
7
8
9

import torch
import torch.nn as nn
10
11
from transformers import (BatchFeature, LlavaOnevisionConfig,
                          LlavaOnevisionProcessor)
12
13
14
15
from transformers.models.llava_onevision.modeling_llava_onevision import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

16
from vllm.config import VllmConfig
17
18
19
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
20
21
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
22
23
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
                                   VideoEmbeddingItems, VideoProcessorItems)
24
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
25
26
from vllm.sequence import IntermediateTensors

27
from .clip import CLIPVisionModel
28
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
29
30
31
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
                         LlavaNextProcessingInfo)
32
from .siglip import SiglipVisionModel
33
34
35
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
36

37
38
39
# For profile run
_MAX_FRAMES_PER_VIDEO = 16

40
41
42

class LlavaOnevisionVideoPixelInputs(TypedDict):
    """Raw video frames to be encoded by the vision tower."""

    type: Literal["pixel_values_videos"]
    pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`

    Note that `num_videos` may be different for each batch, and 'num_frames'
    may be different for each video, in which case the data is passed as a
    list instead of a batched tensor.
    """


class LlavaOnevisionImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
55
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaOnevisionImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that are used as-is (no vision tower)."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The last dimension has to equal the language model's hidden size.
    """


# An image arrives either as raw pixel patches or as ready-made embeddings.
LlavaOnevisionImageInputs = Union[
    LlavaOnevisionImagePixelInputs,
    LlavaOnevisionImageEmbeddingInputs,
]

# Any single-modality input accepted by the model: image or video.
LlavaOnevisionMultiInputs = Union[
    LlavaOnevisionImageInputs,
    LlavaOnevisionVideoPixelInputs,
]
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
    """Structural type for configs accepted by LLaVA-OneVision processing.

    Adds the id of the video placeholder token on top of the LLaVA-NeXT
    config protocol.
    """

    video_token_index: Final[int]
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
    """Processing metadata for LLaVA-OneVision.

    Extends the LLaVA-NeXT processing info with video support: per-frame
    token counts (after spatial pooling) and frame budgeting for the
    profile run.
    """

    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
        return self.ctx.get_hf_config(LlavaOnevisionConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # ``None``: no fixed per-prompt limit on images or videos.
        return {"image": None, "video": None}

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    # with additional logic afterwards taken from LlavaOnevisionProcessor
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Return ``(unpadded_features, newline_features)`` for one image.

        Removes the aspect-ratio padding from the tile grid, then applies the
        same ``anyres_max_9`` downscaling rule (``ratio > 1.1``) that is used
        when merging image patch embeddings.
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        # Strip the padding that was added along the shorter side.
        if aspect_ratio > current_aspect_ratio:
            new_height = int(
                round(original_height * (current_width / original_width), 7))
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            new_width = int(
                round(original_width * (current_height / original_height), 7))
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height

        # Downscale grids larger than 9 base tiles ("anyres_max_9").
        ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
        if ratio > 1.1:
            height_factor = int(current_height // ratio)
            width_factor = int(current_width // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE: This hardcoded value is found via processor tests
        return ImageSize(width=1153, height=944)

    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of tokens one video frame expands to.

        The result depends only on the vision encoder's patch grid and the
        spatial pooling stride (default 2). ``image_width`` and
        ``image_height`` are not used by this implementation.
        """
        hf_config = self.get_hf_config()
        spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

        vision_encoder_info = self.get_vision_encoder_info()
        patch_grid_length = vision_encoder_info.get_patch_grid_length()
        pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

        return pooled_grid_length * pooled_grid_length

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        """Return the placeholder-token count for a video of ``num_frames``."""
        num_frame_tokens = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )

        return num_frame_tokens * num_frames + 1  # Newline token

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose video fits in ``max_tokens``.

        Returns 0 when not even a single frame fits.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame count used for the profile run.

        The tokens left after reserving the maximum image tokens are split
        evenly across videos. The result is capped at
        ``_MAX_FRAMES_PER_VIDEO`` and is never below 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of a video at the profile-run frame count."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
        )

class LlavaOnevisionDummyInputsBuilder(
        LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
    """Builds dummy prompts and media for LLaVA-OneVision profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images/videos sized to produce the most features."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        target_num_frames = \
            self.info.get_num_frames_with_most_features(seq_len,
                                                        mm_counts)

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            )
        }
class LlavaOnevisionMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
    """Multimodal processor adding video handling on top of LLaVA-NeXT."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare image fields batched per image, video pixels per video."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.batched("video"),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling videos one at a time.

        Without videos this defers to the base implementation. With videos,
        the text, the images and each video are processed in separate calls
        and the outputs are merged into one ``BatchFeature``.
        """
        mm_data = dict(mm_data)
        videos = mm_data.pop("videos", [])
        assert isinstance(videos, list)

        if not videos:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
            )

        # LLaVA-OneVision processor doesn't support multiple videos
        # with different sizes when converting back to tensors
        # So, we process each component separately
        # NOTE: No prompt replacement is applied in this case
        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token

        text_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data={},
            mm_kwargs=mm_kwargs,
        )

        images = mm_data.pop("images", [])
        assert isinstance(images, list)
        if images:
            processor_outputs = super()._call_hf_processor(
                prompt=image_token * len(images),
                mm_data={"images": images},
                mm_kwargs=mm_kwargs,
            )
            image_outputs = {
                k: v
                for k, v in processor_outputs.items()
                if k in ("pixel_values", "image_sizes")
            }
        else:
            image_outputs = {}

        pixel_values_videos = []
        for video in videos:
            item_outputs = super()._call_hf_processor(
                prompt=video_token,
                mm_data={"videos": video},
                mm_kwargs=mm_kwargs,
            )

            pixel_values_videos.append(item_outputs["pixel_values_videos"][0])

        video_outputs = {"pixel_values_videos": pixel_values_videos}

        combined_outputs = dict(
            text_outputs,
            **image_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        """Report whether the HF processor already expanded placeholders.

        Always ``False`` when videos are present, because
        ``_call_hf_processor`` handles videos without prompt replacement.
        """
        base_result = super()._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return base_result and mm_items.get_count("video", strict=False) == 0

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the base image updates plus a video placeholder update.

        Each video token is replaced by as many video tokens as the video
        expands to (from its embeddings, or computed from frame count).
        """
        image_repls = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        hf_config = self.info.get_hf_config()
        video_token_id = hf_config.video_token_index

        def get_video_replacement(item_idx: int):
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems))

            if isinstance(videos, VideoEmbeddingItems):
                num_video_tokens = videos.get_feature_size(item_idx)
            else:
                image_size = videos.get_frame_size(item_idx)
                num_video_tokens = self.info.get_num_video_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    num_frames=videos.get_num_frames(item_idx),
                )

            return [video_token_id] * num_video_tokens

        return [
            *image_repls,
            PromptReplacement(
                modality="video",
                target=[video_token_id],
                replacement=get_video_replacement,
            ),
        ]

class LlavaOnevisionMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features to the text hidden size.

    ``linear_1`` maps vision hidden size -> text hidden size, followed by the
    configured activation and ``linear_2`` (text -> text hidden size).
    """

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()

        self.linear_1 = nn.Linear(config.vision_config.hidden_size,
                                  config.text_config.hidden_size,
                                  bias=config.multimodal_projector_bias)
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(config.text_config.hidden_size,
                                  config.text_config.hidden_size,
                                  bias=config.multimodal_projector_bias)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
@MULTIMODAL_REGISTRY.register_processor(
    LlavaOnevisionMultiModalProcessor,
    info=LlavaOnevisionProcessingInfo,
    dummy_inputs=LlavaOnevisionDummyInputsBuilder)
429
430
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):
431

432
433
434
435
436
437
438
439
440
441
    # Maps checkpoint weight-name prefixes onto this module's attribute
    # names when loading weights.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        })
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector, language model and newline.

        Args:
            vllm_config: Global vLLM config; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Parameter-name prefix for the submodules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Separator embedding concatenated after rows of image patch features.
        # Allocated uninitialized; presumably filled from the checkpoint
        # ("model.image_newline" in hf_to_vllm_mapper).
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors)
    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Ensure every per-image size entry has shape ``(2,)``.

        Returns ``data`` untouched. Raises ``ValueError`` at the first entry
        whose shape differs.
        """
        expected_dims = (2, )
        expected_expr = str(expected_dims)

        for size_entry in data:
            if tuple(size_entry.shape) == expected_dims:
                continue
            raise ValueError(
                f"The expected shape of image sizes per image per batch "
                f"is {expected_expr}. You supplied {tuple(size_entry.shape)}.")

        return data
    def _validate_image_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check each image's patches have shape ``(num_patches, 3, H, W)``.

        ``H`` and ``W`` both equal ``vision_config.image_size``. ``data`` is
        returned unchanged; ``ValueError`` is raised on the first mismatch.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # Leading dim is the (variable) patch count; check the rest.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
        """Build an image input from ``kwargs``; ``None`` if none is given.

        ``pixel_values`` (which then requires ``image_sizes``) takes
        precedence over ``image_embeds``. Raises ``ValueError`` when a value
        has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaOnevisionImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_image_pixel_values(
                    flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaOnevisionImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")
    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check each video's frames have shape ``(num_frames, 3, H, W)``.

        ``H`` and ``W`` both equal ``vision_config.image_size``. The trailing
        ``(3, H, W)`` dims are compared, so a per-video tensor
        ``(num_frames, 3, H, W)`` and an extra leading batch dim are both
        accepted. ``data`` is returned unchanged; ``ValueError`` is raised on
        the first mismatch.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # Previously ``d.shape[2:]``, which for a (num_frames, C, H, W)
            # tensor yields only (H, W) and can never equal (3, h, w).
            actual_dims = tuple(d.shape[-3:]) if d.dim() >= 4 else None

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data
    def _parse_and_validate_video_input(
            self,
            **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
        """
        A legal video input should have the following dimensions:
        {
            "pixel_values_videos" :
                list[b, Tensor(nb_frames, nb_channels, height, width)]
        }

        Returns None when ``pixel_values_videos`` is absent. Raises
        ValueError if it is neither a tensor nor a list.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        if pixel_values_videos is None:
            return None

        if not isinstance(pixel_values_videos, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel_values_videos. "
                             f"Got type: {type(pixel_values_videos)}")

        return LlavaOnevisionVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=flatten_bn(pixel_values_videos),
        )
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs from ``kwargs``, keyed by modality.

        The returned dict's insertion order follows the first occurrence of
        each modality's key in ``kwargs``.

        NOTE(review): a ``video_embeds`` key triggers video parsing, but
        ``_parse_and_validate_video_input`` only reads
        ``pixel_values_videos``. In that case the "video" entry may be None.
        """
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values", "image_embeds"
                             ) and "image" not in mm_input_by_modality:
                mm_input_by_modality[
                    "image"] = self._parse_and_validate_image_input(**kwargs)
            if input_key in ("pixel_values_videos", "video_embeds"
                             ) and "video" not in mm_input_by_modality:
                mm_input_by_modality[
                    "video"] = self._parse_and_validate_video_input(**kwargs)

        return mm_input_by_modality
    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature selection strategy.

        ``"full"`` keeps every token. ``"default"`` drops the first token
        along dim 1 (presumably the CLS token). Any other value raises
        ``ValueError``.
        """
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels and apply the configured feature selection.

        The vision tower is built to stop at the required feature layer, so
        no layer selection is done here.
        """
        encoded = vision_tower(pixel_values)
        select_strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(encoded, strategy=select_strategy)
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self,
                                      image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor,
                                      *,
                                      image_newline=None,
                                      vision_aspect_ratio="anyres_max_9",
                                      strategy: str) -> torch.Tensor:
        """Merge one image's patch embeddings into a token sequence.

        Args:
            image_size: ``(height, width)`` of the original image.
            patch_embeddings: Embeddings of the base patch
                (``patch_embeddings[0]``) followed by the anyres tiles. Each
                patch must hold ``(image_size // patch_size) ** 2`` tokens.
            image_newline: Optional embedding concatenated after each row of
                the unpadded tile grid, and after the base patch when it is
                the only patch. ``None`` disables newline tokens in both
                cases.
            vision_aspect_ratio: ``"anyres_max_N"``. The unpadded grid is
                bilinearly downscaled when
                ``sqrt(area / (N * base_area)) > 1.1``.
            strategy: ``"flat"``, or a strategy starting with ``"spatial"``.
                If it contains ``"unpad"``, aspect-ratio padding is removed.

        Returns:
            The merged embeddings, concatenated along dim 0.

        Raises:
            ValueError: on an unknown strategy or a base patch whose token
                count does not match the configured image size.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU (Python ints) to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # Rearrange to (hidden, grid_h * height, grid_w * width).
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    max_num_patches = int(
                        vision_aspect_ratio.removeprefix("anyres_max_"))
                    _, curr_height, curr_width = other_patch_embeds.shape
                    ratio = math.sqrt(curr_height * curr_width /
                                      (max_num_patches * height**2))
                    if ratio > 1.1:
                        other_patch_embeds = other_patch_embeds[None]
                        other_patch_embeds = nn.functional.interpolate(
                            other_patch_embeds, [
                                int(curr_height // ratio),
                                int(curr_width // ratio)
                            ],
                            mode="bilinear")[0]
                    if image_newline is not None:
                        # One newline column appended at the end of each row.
                        newline_column = image_newline[:, None, None].expand(
                            *other_patch_embeds.shape[:-1], 1).to(
                                other_patch_embeds.device)
                        other_patch_embeds = torch.cat(
                            (other_patch_embeds, newline_column), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            elif "unpad" in strategy and image_newline is not None:
                # Use the caller-supplied newline (not self.image_newline) so
                # this branch agrees with the multi-patch branch above.
                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds,
                     image_newline[None].to(base_patch_embeds.device)),
                    dim=0)
            else:
                merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")
    def _process_image_pixels(
        self,
        inputs: LlavaOnevisionImagePixelInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Encode and project image pixel patches.

        Returns a tensor shaped ``(b, num_patches, ...)`` when the pixel
        values are one batched tensor. Otherwise returns one projected tensor
        per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable patch counts per image: encode all patches in one pass,
        # then split back per image before projecting.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]
    def _process_image_input(
        self,
        image_input: LlavaOnevisionImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn an image input into per-image embedding sequences.

        Precomputed embeddings are returned as a single-element list.
        Pixel inputs are encoded and then merged with the "spatial_unpad"
        strategy, using ``self.image_newline`` as the row separator.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # No sizes supplied: treat each image as exactly one
            # vision-encoder tile of image_size x image_size.
            batch_size = len(image_input["pixel_values"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(
                image_sizes[i],
                patch_features_batch,
                image_newline=self.image_newline,
                strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]
    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        video_features = vision_tower(pixel_values)
        video_features = self._select_image_features(
            video_features,
            strategy=self.config.vision_feature_select_strategy,
        )
        video_features = self.multi_modal_projector(video_features)
        video_features = self.apply_pooling(video_features)
        return video_features

    def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
        assert self.vision_tower is not None

794
        video_pixels = inputs["pixel_values_videos"]
795
796

        if isinstance(video_pixels, torch.Tensor):
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
            total_videos, frames, c, h, w = video_pixels.shape
            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
                                                  w)

            embeddings_flat = self._video_pixels_to_features(
                self.vision_tower, video_pixels_flat)

            embeddings_flat = embeddings_flat.reshape(
                total_videos, frames * embeddings_flat.shape[1], -1)

            image_newline = self.image_newline[None, None, :].expand(
                total_videos, -1, -1)
            return torch.cat((embeddings_flat, image_newline), dim=1)

        frames_per_video = [len(video) for video in video_pixels]
        video_pixels_flat = torch.cat(video_pixels)

        embeddings_flat = self._video_pixels_to_features(
            self.vision_tower, video_pixels_flat)

        image_newline = self.image_newline[None, None, :]

        return [
            torch.cat(
                (
                    embeds.reshape(1, num_frame * embeddings_flat.shape[1],
                                   -1),
                    image_newline,
                ),
                dim=1,
            ) for num_frame, embeds in zip(
                frames_per_video,
                torch.split(embeddings_flat, frames_per_video),
            )
        ]
832

833
    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
        """Spatially downsample per-frame patch features.

        The tokens of each frame are laid out on a square grid with
        ``side = image_size // patch_size`` cells per side. The grid is
        resized to ``ceil(side / stride)`` cells per side with bilinear
        interpolation and then flattened back into a token sequence.

        Args:
            image_features: Tensor unpacked as ``(batch_frames, tokens, dim)``;
                ``tokens`` is expected to be ``side * side``.
            stride: Downsampling factor applied to both grid dimensions.

        Returns:
            Tensor of shape ``(batch_frames, ceil(side / stride) ** 2, dim)``.
        """
        vision_config = self.config.vision_config
        side = vision_config.image_size // vision_config.patch_size
        batch_frames, _, dim = image_features.shape

        # (B, tokens, dim) -> (B, dim, side, side) so interpolate treats the
        # feature dimension as channels.
        grid = image_features.view(batch_frames, side, side, -1)
        grid = grid.permute(0, 3, 1, 2)

        # TODO support other pooling types config
        pooled_side = math.ceil(side / stride)
        pooled = nn.functional.interpolate(grid,
                                           size=[pooled_side, pooled_side],
                                           mode='bilinear')

        # (B, dim, h', w') -> (B, h' * w', dim)
        pooled = pooled.permute(0, 2, 3, 1)
        return pooled.view(batch_frames, -1, dim)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped ``language_model`` submodule."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for every image and video item in ``kwargs``.

        Returns:
            An empty list when no multimodal input is present. Otherwise a
            tuple with one tensor per multimodal item. Modalities appear in
            the iteration order of the mapping returned by
            ``_parse_and_validate_multimodal_inputs``, and items within a
            modality keep the order produced by ``_process_image_input`` or
            ``_process_video_pixels``. Modalities other than ``"image"`` and
            ``"video"`` are ignored.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            **kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += tuple(vision_embeddings)
            elif modality == "video":
                video_embeddings = self._process_video_pixels(multimodal_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings if given.

        Text tokens are embedded by the language model. When
        ``multimodal_embeddings`` is non-empty, the result is passed to
        ``merge_multimodal_embeddings`` together with both the image and the
        video placeholder token ids.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        # An explicit len() check (rather than truthiness) also works when
        # the embeddings are a single tensor.
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_index, self.config.video_token_index])
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[LlavaOnevisionImageInputs] = None,
        video_input: Optional[LlavaOnevisionVideoPixelInputs] = None,
    ) -> torch.Tensor:
        """Build input embeddings for the v0 engine path.

        Text tokens are embedded with ``get_input_embeddings``. Image and
        video inputs, when present, are encoded and merged separately at the
        image and video placeholder token ids, respectively.

        ``image_input`` may be pixel input or precomputed embeddings; both
        are accepted by ``_process_image_input``.
        """
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_index,
            )

        if video_input is not None:
            video_embeds = self._process_video_pixels(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_index,
            )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-Onevision.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions forwarded unchanged to the language model.
            intermediate_tensors: Intermediate tensors (presumably from an
                earlier pipeline stage). When given, any ``inputs_embeds``
                argument is discarded.
            inputs_embeds: Precomputed input embeddings. When omitted and no
                ``intermediate_tensors`` are given, embeddings are built here
                from the multimodal ``kwargs`` if any are present.
            **kwargs: Multimodal inputs handed to
                ``_parse_and_validate_image_input`` and
                ``_parse_and_validate_video_input``, e.g.
                ``pixel_values_videos`` (pixels of each frame of each video).

        Returns:
            Whatever ``self.language_model.model`` returns.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner
            # from `get_multimodal_embeddings` and `get_input_embeddings`,
            # this condition is only for v0 compatibility.
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)

            if image_input is not None or video_input is not None:
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input)
                # The embeddings already include the text tokens.
                input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Loading is delegated to ``AutoWeightsLoader``, with
        ``self.hf_to_vllm_mapper`` passed as its ``mapper``.

        Returns:
            The set of strings returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)