llava_onevision.py 36.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Final, Literal, Optional, Protocol, TypedDict, Union
7
8
9

import torch
import torch.nn as nn
10
11
from transformers import (BatchFeature, LlavaOnevisionConfig,
                          LlavaOnevisionProcessor)
12
13
14
15
from transformers.models.llava_onevision.modeling_llava_onevision import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

16
from vllm.config import VllmConfig
17
18
19
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
20
21
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
22
23
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
                                   VideoEmbeddingItems, VideoProcessorItems)
24
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
25
26
from vllm.sequence import IntermediateTensors

27
from .clip import CLIPVisionModel
28
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
29
30
31
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
                         LlavaNextProcessingInfo)
32
from .siglip import SiglipVisionModel
33
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
34
                    maybe_prefix, merge_multimodal_embeddings)
35

36
37
38
# For profile run
_MAX_FRAMES_PER_VIDEO = 16

39
40
41

class LlavaOnevisionVideoPixelInputs(TypedDict):
    """Video inputs supplied as raw pixel values."""

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`

    `num_videos` can vary between batches and `num_frames` can vary between
    videos; when the data cannot be stacked into a single tensor it is passed
    as a list of tensors instead.
    """


class LlavaOnevisionImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
54
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaOnevisionImageEmbeddingInputs(TypedDict):
    """Image inputs supplied as precomputed embeddings."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` has to equal the hidden size of the language model backbone.
    """


# Image input in either accepted form: raw pixels or precomputed embeddings.
LlavaOnevisionImageInputs = Union[
    LlavaOnevisionImagePixelInputs,
    LlavaOnevisionImageEmbeddingInputs,
]

# Input for either modality handled by this model (image or video).
LlavaOnevisionMultiInputs = Union[
    LlavaOnevisionImageInputs,
    LlavaOnevisionVideoPixelInputs,
]


87
88
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
    """Structural config type: a LLaVA-NeXT-like config that additionally
    exposes the id of the video placeholder token."""

    video_token_index: Final[int]


91
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
    """HF config/processor access plus image and video token-count math for
    LLaVA-OneVision."""

    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
        return self.ctx.get_hf_config(LlavaOnevisionConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Neither modality has a fixed per-prompt limit.
        return {"image": None, "video": None}

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    # with additional logic afterwards taken from LlavaOnevisionProcessor
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count feature tokens that remain after padding is removed.

        Returns:
            ``(unpadded_features, newline_features)``. If the unpadded grid
            is more than 1.1x (by side-length ratio) larger than
            ``9 * npatches**2`` tokens, both counts are taken from a grid
            shrunk by that ratio.
        """
        grid_h = npatches * num_patch_height
        grid_w = npatches * num_patch_width

        if original_width / original_height > grid_w / grid_h:
            # Image is relatively wider than the grid: trim the height.
            fitted_h = int(
                round(original_height * (grid_w / original_width), 7))
            grid_h -= 2 * ((grid_h - fitted_h) // 2)
        else:
            # Otherwise trim the width.
            fitted_w = int(
                round(original_width * (grid_h / original_height), 7))
            grid_w -= 2 * ((grid_w - fitted_w) // 2)

        ratio = math.sqrt(grid_h * grid_w / (9 * npatches**2))
        if ratio > 1.1:
            rows = int(grid_h // ratio)
            cols = int(grid_w // ratio)
            return (rows * cols, rows)

        return (grid_h * grid_w, grid_h)

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE: This hardcoded value is found via processor tests
        return ImageSize(width=1153, height=944)

    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Tokens per video frame after spatial pooling.

        The frame size arguments are not used: the result depends only on
        the vision encoder's patch grid length and the config's
        ``spatial_pool_stride`` (2 when absent).
        """
        stride = getattr(self.get_hf_config(), "spatial_pool_stride", 2)
        grid = self.get_vision_encoder_info().get_patch_grid_length()
        side = math.ceil(grid / stride)
        return side * side

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        """Tokens for one video: per-frame tokens for every frame plus a
        single trailing newline token."""
        per_frame = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return per_frame * num_frames + 1

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose video token count is at most
        ``max_tokens`` (0 if not even one frame fits)."""
        width, height = self.get_image_size_with_most_features()

        frames = 0
        while self.get_num_video_tokens(
                image_width=width,
                image_height=height,
                num_frames=frames + 1,
        ) <= max_tokens:
            frames += 1

        return frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video to use for the profiling run.

        The sequence budget left after reserving the max image tokens for
        every image is split evenly across videos, capped at
        ``_MAX_FRAMES_PER_VIDEO`` and never below 1.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        video_budget = seq_len - self.get_max_image_tokens() * num_images
        total_frames = self._get_max_video_frames(video_budget)
        per_video = min(total_frames // max(num_videos, 1),
                        _MAX_FRAMES_PER_VIDEO)

        return max(per_video, 1)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Token count of a video at the largest frame size, using the frame
        count from ``get_num_frames_with_most_features``."""
        width, height = self.get_image_size_with_most_features()
        frames = self.get_num_frames_with_most_features(seq_len, mm_counts)

        return self.get_num_video_tokens(
            image_width=width,
            image_height=height,
            num_frames=frames,
        )

225
226
227
228

class LlavaOnevisionDummyInputsBuilder(
        LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
    """Builds dummy text and image/video data sized to the largest inputs
    reported by the processing info."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One image token per image, then one video token per video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        return (hf_processor.image_token * num_images +
                hf_processor.video_token * num_videos)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images)
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=num_frames,
            num_videos=num_videos,
        )
        return {"image": dummy_images, "video": dummy_videos}


268
269
class LlavaOnevisionMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
    """Multimodal processor for LLaVA-OneVision.

    Image handling comes from the parent class; this class adds the video
    fields, per-video HF processing and the video prompt replacement.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Each output field is batched over the items of its modality.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_sizes": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "pixel_values_videos": MultiModalFieldConfig.batched("video"),
        }

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, splitting the work when videos are present.

        Without videos this defers to the parent implementation. With
        videos, the text, the images and each individual video are processed
        in separate calls and their outputs merged.
        """
        remaining = dict(mm_data)
        videos = remaining.pop("videos", [])
        assert isinstance(videos, list)

        if not videos:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=remaining,
                mm_kwargs=mm_kwargs,
            )

        # LLaVA-OneVision processor doesn't support multiple videos
        # with different sizes when converting back to tensors
        # So, we process each component separately
        # NOTE: No prompt replacement is applied in this case
        hf_processor = self.info.get_hf_processor()
        image_token = hf_processor.image_token
        video_token = hf_processor.video_token

        text_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data={},
            mm_kwargs=mm_kwargs,
        )

        images = remaining.pop("images", [])
        assert isinstance(images, list)

        image_outputs = {}
        if images:
            image_result = super()._call_hf_processor(
                prompt=image_token * len(images),
                mm_data={"images": images},
                mm_kwargs=mm_kwargs,
            )
            image_keys = ("pixel_values", "image_sizes")
            image_outputs = {
                key: value
                for key, value in image_result.items() if key in image_keys
            }

        # A plain loop (not a comprehension) so that zero-argument super()
        # resolves against this method's frame.
        pixel_values_videos = []
        for video in videos:
            video_result = super()._call_hf_processor(
                prompt=video_token,
                mm_data={"videos": video},
                mm_kwargs=mm_kwargs,
            )
            pixel_values_videos.append(video_result["pixel_values_videos"][0])

        combined = dict(text_outputs)
        combined.update(image_outputs)
        combined["pixel_values_videos"] = pixel_values_videos
        return BatchFeature(combined)

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        applies = super()._hf_processor_applies_updates(
            prompt_text=prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )
        if not applies:
            return applies

        # With videos, `_call_hf_processor` processes components separately
        # and applies no prompt replacement itself.
        return mm_items.get_count("video", strict=False) == 0

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        video_token_id = self.info.get_hf_config().video_token_index

        def get_video_replacement(item_idx: int):
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems))

            if isinstance(videos, VideoEmbeddingItems):
                return [video_token_id] * videos.get_feature_size(item_idx)

            frame_size = videos.get_frame_size(item_idx)
            token_count = self.info.get_num_video_tokens(
                image_width=frame_size.width,
                image_height=frame_size.height,
                num_frames=videos.get_num_frames(item_idx),
            )
            return [video_token_id] * token_count

        video_update = PromptReplacement(
            modality="video",
            target=[video_token_id],
            replacement=get_video_replacement,
        )
        return [*image_updates, video_update]

403
404
405
406
407
408
409
410

class LlavaOnevisionMultiModalProjector(nn.Module):
    """Two linear layers with an activation in between, mapping vision
    features to the text model's hidden size."""

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()

        use_bias = config.multimodal_projector_bias
        text_hidden = config.text_config.hidden_size

        self.linear_1 = nn.Linear(config.vision_config.hidden_size,
                                  text_hidden,
                                  bias=use_bias)
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden, text_hidden, bias=use_bias)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        projected = self.linear_1(image_features)
        return self.linear_2(self.act(projected))


424
425
426
427
@MULTIMODAL_REGISTRY.register_processor(
    LlavaOnevisionMultiModalProcessor,
    info=LlavaOnevisionProcessingInfo,
    dummy_inputs=LlavaOnevisionDummyInputsBuilder)
428
429
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):
430

431
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Learned "newline" embedding (text hidden size) that is inserted
        # into the merged image and video feature sequences.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors)
    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Ensure every entry of ``data`` has shape ``(2,)``.

        Returns ``data`` unchanged; raises ``ValueError`` at the first entry
        with a different shape.
        """
        expected_dims = (2, )
        expected_expr = str(expected_dims)

        for size in data:
            if tuple(size.shape) == expected_dims:
                continue
            raise ValueError(
                f"The expected shape of image sizes per image per batch "
                f"is {expected_expr}. You supplied {tuple(size.shape)}.")

        return data

    def _validate_image_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Ensure each image in ``data`` is ``(num_patches, 3, size, size)``.

        ``size`` is ``vision_config.image_size``; only dims from index 1 on
        are checked. Returns ``data`` unchanged or raises ``ValueError``.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)
        expected_expr = ("num_patches", *map(str, expected_dims))

        for image in data:
            if tuple(image.shape[1:]) == expected_dims:
                continue
            raise ValueError(
                "The expected shape of pixel values per image per batch "
                f"is {expected_expr}. You supplied {tuple(image.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
        """Build typed image inputs from the model kwargs.

        Pixel values take precedence over ``image_embeds``; ``None`` is
        returned when neither is given. Raises ``ValueError`` on unexpected
        argument types or shapes.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            checked_pixels = self._validate_image_pixel_values(
                flatten_bn(pixel_values))
            checked_sizes = self._validate_image_sizes(
                flatten_bn(image_sizes, concat=True))
            return LlavaOnevisionImagePixelInputs(
                type="pixel_values",
                pixel_values=checked_pixels,
                image_sizes=checked_sizes,
            )

        if image_embeds is None:
            return None

        if not isinstance(image_embeds, torch.Tensor):
            raise ValueError("Incorrect type of image embeds. "
                             f"Got type: {type(image_embeds)}")

        return LlavaOnevisionImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds),
        )

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check the trailing ``(3, size, size)`` dims of every item in data.

        ``size`` is ``vision_config.image_size``. Each item is compared from
        dim index 2 onward, so items are expected to carry two leading dims
        before the channel dim. Returns ``data`` unchanged or raises
        ``ValueError``.
        """
        # NOTE(review): the error message lists a single leading dim
        # ("num_frames") while the check skips two, and this validator is
        # not called from `_parse_and_validate_video_input` — confirm the
        # intended input layout before wiring it in.
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)
        expected_expr = ("num_frames", *map(str, expected_dims))

        for video in data:
            if tuple(video.shape[2:]) == expected_dims:
                continue
            raise ValueError(
                "The expected shape of pixel values in each video frame "
                f"is {expected_expr}. You supplied {tuple(video.shape)}.")

        return data

    def _parse_and_validate_video_input(
            self,
            **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
        """Build typed video inputs from ``pixel_values_videos``.

        Expected layout:
            list[b, Tensor(nb_frames, nb_channels, height, width)]
        (a ``torch.Tensor`` is accepted as well). Returns ``None`` when the
        key is absent; raises ``ValueError`` for any other type.
        """
        videos = kwargs.pop("pixel_values_videos", None)
        if videos is None:
            return None

        if isinstance(videos, (torch.Tensor, list)):
            return LlavaOnevisionVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=flatten_bn(videos),
            )

        raise ValueError("Incorrect type of pixel_values_videos. "
                         f"Got type: {type(videos)}")

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs, keyed by modality.

        Modalities appear in the order their first matching kwarg appears,
        and each modality is parsed at most once.
        """
        image_keys = ("pixel_values", "image_embeds")
        video_keys = ("pixel_values_videos", "video_embeds")

        parsed = {}
        for key in kwargs:
            if key in image_keys and "image" not in parsed:
                parsed["image"] = self._parse_and_validate_image_input(
                    **kwargs)
            if key in video_keys and "video" not in parsed:
                parsed["video"] = self._parse_and_validate_video_input(
                    **kwargs)

        return parsed

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature selection strategy.

        ``"default"`` drops the first token along dim 1 (presumably the CLS
        token — TODO confirm for every supported tower); ``"full"`` keeps
        all tokens. Any other value raises ``ValueError``.
        """
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels with the vision tower, then apply the configured
        feature selection strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)
        strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(tower_output, strategy=strategy)

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(
        self,
        image_size: torch.Tensor,
        patch_embeddings: torch.Tensor,
        *,
        image_newline: Optional[torch.Tensor] = None,
        vision_aspect_ratio: str = "anyres_max_9",
        strategy: str,
    ) -> torch.Tensor:
        """Merge one image's base and patch embeddings into a token sequence.

        Args:
            image_size: Original ``(height, width)`` of the image.
            patch_embeddings: ``(1 + num_patches, h * w, hidden)``, where
                index 0 is the base view and ``h = w = image_size //
                patch_size`` from the vision config. Extra trailing patches
                (batch padding) are ignored.
            image_newline: Optional ``(hidden,)`` embedding used only by
                "unpad" strategies. It is appended after every row of the
                unpadded patch grid, or once after the base features when
                there are no extra patches. ``None`` disables it.
            vision_aspect_ratio: ``"anyres_max_<N>"``; the unpadded grid is
                bilinearly downscaled when it exceeds roughly ``N`` base
                grids by more than a 1.1 side-length ratio.
            strategy: ``"flat"``, or a strategy starting with ``"spatial"``.
                "unpad" anywhere in the name enables padding removal and
                newline insertion.

        Returns:
            The merged ``(seq_len, hidden)`` embeddings.

        Raises:
            ValueError: For an unknown ``strategy``, or when the base patch
                token count does not match the vision config grid.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if not strategy.startswith("spatial"):
            raise ValueError(f"Unexpected patch merge strategy: {strategy}")

        height = width = self.config.vision_config.image_size \
            // self.config.vision_config.patch_size

        base_patch_embeds = patch_embeddings[0]
        if height * width != base_patch_embeds.shape[0]:
            raise ValueError(
                "The number of patches is not consistent with the "
                "image size.")

        unpad = "unpad" in strategy

        if patch_embeddings.shape[0] <= 1:
            # Only the base view is present. Honour the `image_newline`
            # argument here as well, consistent with the multi-patch path.
            if unpad and image_newline is not None:
                return torch.cat(
                    (base_patch_embeds,
                     image_newline[None].to(base_patch_embeds.device)),
                    dim=0)
            return base_patch_embeds

        other_patch_embeds = patch_embeddings[1:]

        # Move to CPU to avoid floating-point errors
        orig_height, orig_width = image_size.tolist()

        # image_aspect_ratio == "anyres"
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            (orig_height, orig_width),
            self.config.image_grid_pinpoints,
            self.config.vision_config.image_size,
        )
        num_patches = num_patch_height * num_patch_width

        # Image patches might be padded for batch processing
        other_patch_embeds = other_patch_embeds[:num_patches] \
            .view(num_patch_height, num_patch_width, height, width, -1)

        if unpad:
            # (nph, npw, h, w, C) -> (C, nph * h, npw * w)
            other_patch_embeds = other_patch_embeds \
                .permute(4, 0, 2, 1, 3).contiguous() \
                .flatten(1, 2).flatten(2, 3)
            other_patch_embeds = unpad_image(other_patch_embeds,
                                             (orig_height, orig_width))

            max_num_patches = int(
                vision_aspect_ratio.removeprefix("anyres_max_"))
            _, curr_height, curr_width = other_patch_embeds.shape
            ratio = math.sqrt(curr_height * curr_width /
                              (max_num_patches * height**2))
            if ratio > 1.1:
                other_patch_embeds = nn.functional.interpolate(
                    other_patch_embeds[None],
                    [int(curr_height // ratio),
                     int(curr_width // ratio)],
                    mode="bilinear")[0]

            if image_newline is not None:
                # Add the newline embedding as one extra column (C, H, 1),
                # i.e. once at the end of every row.
                newline_column = image_newline[:, None, None] \
                    .expand(*other_patch_embeds.shape[:-1], 1) \
                    .to(other_patch_embeds.device)
                other_patch_embeds = torch.cat(
                    (other_patch_embeds, newline_column), dim=-1)

            # (C, H, W') -> (H * W', C)
            other_patch_embeds = other_patch_embeds \
                .flatten(1, 2).transpose(0, 1)
        else:
            # (nph, npw, h, w, C) -> (nph * h * npw * w, C)
            other_patch_embeds = other_patch_embeds \
                .permute(0, 2, 1, 3, 4).contiguous() \
                .flatten(0, 3)

        return torch.cat((base_patch_embeds, other_patch_embeds), dim=0)

    def _process_image_pixels(
        self,
        inputs: LlavaOnevisionImagePixelInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Encode and project the image patches.

        A batched tensor ``(b, num_patches, c, h, w)`` yields one tensor
        whose leading dims are ``(b, num_patches)``. A list input is encoded
        in one vision-tower pass and yields one projected tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if not isinstance(pixel_values, torch.Tensor):
            patch_counts = [image.shape[0] for image in pixel_values]
            all_features = self._image_pixels_to_features(
                self.vision_tower, torch.cat(pixel_values))
            return [
                self.multi_modal_projector(chunk)
                for chunk in torch.split(all_features, patch_counts)
            ]

        b, num_patches, c, h, w = pixel_values.shape
        features = self._image_pixels_to_features(
            self.vision_tower, pixel_values.view(b * num_patches, c, h, w))
        projected = self.multi_modal_projector(features)

        return projected.view(b, num_patches, *projected.shape[1:])

    def _process_image_input(
        self,
        image_input: LlavaOnevisionImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn parsed image inputs into per-image embedding sequences.

        Precomputed embeddings are returned as-is inside a one-element list.
        Pixel inputs are encoded, then each image's patches are merged with
        the "spatial_unpad" strategy. When ``image_sizes`` is missing, every
        image is assumed to be ``vision_config.image_size`` square.
        """
        if image_input["type"] == "image_embeds":
            # NOTE(review): the whole embeddings tensor becomes a single
            # list entry here, unlike the per-image list below — confirm
            # this is what callers expect.
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            side = self.config.vision_config.image_size
            num_items = len(image_input["pixel_values"])
            image_sizes = torch.as_tensor([[side, side]
                                           for _ in range(num_items)])

        merged = []
        for idx, per_image in enumerate(patch_embeddings):
            merged.append(
                self._merge_image_patch_embeddings(
                    image_sizes[idx],
                    per_image,
                    image_newline=self.image_newline,
                    strategy="spatial_unpad",
                ))
        return merged

    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode a flat batch of video frames into pooled, projected tokens.

        The frames pass through these stages in order:
        ``vision_tower`` -> ``_select_image_features`` (using
        ``config.vision_feature_select_strategy``) ->
        ``multi_modal_projector`` -> ``apply_pooling``.
        """
        # NOTE: the vision feature layer is not selected here because the
        # vision tower already does that internally.
        selected = self._select_image_features(
            vision_tower(pixel_values),
            strategy=self.config.vision_feature_select_strategy,
        )
        return self.apply_pooling(self.multi_modal_projector(selected))

    def _process_video_pixels(
        self, inputs: LlavaOnevisionVideoPixelInputs
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Encode video frames into one embedding sequence per video.

        Each frame is run through ``_video_pixels_to_features`` (vision
        tower, projector and pooling). All tokens of one video's frames are
        concatenated along the token axis, and one ``image_newline`` token is
        appended at the end.

        Args:
            inputs: Parsed video inputs. ``pixel_values_videos`` is either a
                5D tensor ``(num_videos, num_frames, C, H, W)`` or a list
                with one tensor per video whose leading dimension is that
                video's frame count.

        Returns:
            For tensor input, a tensor ``(num_videos, seq_len, hidden)``.
            Iterating it yields one 2D ``(seq_len, hidden)`` tensor per video.
            For list input, a list of 2D ``(seq_len_i, hidden)`` tensors, one
            per video, so both branches yield items of the same rank.
        """
        assert self.vision_tower is not None

        video_pixels = inputs["pixel_values_videos"]

        if isinstance(video_pixels, torch.Tensor):
            total_videos, frames, c, h, w = video_pixels.shape
            video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
                                                  w)

            embeddings_flat = self._video_pixels_to_features(
                self.vision_tower, video_pixels_flat)

            # (videos * frames, tokens, hidden) -> (videos, frames * tokens,
            # hidden), then append one newline token to every video.
            embeddings_flat = embeddings_flat.reshape(
                total_videos, frames * embeddings_flat.shape[1], -1)

            image_newline = self.image_newline[None, None, :].expand(
                total_videos, -1, -1)
            return torch.cat((embeddings_flat, image_newline), dim=1)

        # List input: encode all frames in a single pass, then split the
        # features back per video by frame count.
        frames_per_video = [len(video) for video in video_pixels]
        video_pixels_flat = torch.cat(video_pixels)

        embeddings_flat = self._video_pixels_to_features(
            self.vision_tower, video_pixels_flat)

        # Shape (1, hidden); appended along the token axis below.
        image_newline = self.image_newline[None, :]

        # Emit 2D (frames * tokens + 1, hidden) tensors, matching the
        # per-video items of the tensor branch. A leading singleton batch
        # dim would move the token axis to dim 1 for these items only.
        return [
            torch.cat((embeds.flatten(0, 1), image_newline), dim=0)
            for embeds in torch.split(embeddings_flat, frames_per_video)
        ]

    def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
        """Spatially downsample per-frame visual tokens by bilinear resizing.

        The token axis is viewed as a square ``grid x grid`` map, where
        ``grid = image_size // patch_size`` from the vision config. The map
        is resized to ``ceil(grid / stride)`` on each side and flattened back
        into a token axis.

        Args:
            image_features: Assumed to have shape
                ``(num_frames, grid * grid, dim)``.
            stride: Downsampling factor applied to both spatial sides.

        Returns:
            Tensor of shape ``(num_frames, ceil(grid / stride) ** 2, dim)``.
        """
        vision_cfg = self.config.vision_config
        grid = vision_cfg.image_size // vision_cfg.patch_size
        num_frames, _, dim = image_features.shape

        # (frames, tokens, dim) -> (frames, dim, grid, grid) for interpolate.
        spatial = image_features.view(num_frames, grid, grid,
                                      -1).permute(0, 3, 1, 2)

        # TODO support other pooling types config
        target_size = [math.ceil(side / stride) for side in spatial.shape[2:]]
        pooled = nn.functional.interpolate(spatial,
                                           size=target_size,
                                           mode='bilinear')

        # Back to channels-last, then flatten the pooled grid into tokens.
        return pooled.permute(0, 2, 3, 1).view(num_frames, -1, dim)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped ``language_model`` submodule."""
        lm = self.language_model
        return lm

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute embeddings for the image and video inputs in ``kwargs``.

        Returns:
            ``None`` when parsing yields no multimodal inputs. Otherwise, a
            tuple of embedding tensors from ``_process_image_input`` (images)
            and ``_process_video_pixels`` (videos), ordered as the modalities
            appear in the parsed inputs. Modalities other than ``"image"``
            and ``"video"`` contribute nothing.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            **kwargs)
        if not mm_input_by_modality:
            return None

        # The resulting multimodal_embeddings is a tuple of tensors, with each
        # tensor intended to correspond to a multimodal data item (image or
        # video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over this dictionary in order to
        # preserve the order of the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += tuple(vision_embeddings)
            elif modality == "video":
                video_embeddings = self._process_video_pixels(multimodal_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in any multimodal embeddings.

        The language model embeds the text tokens. When
        ``multimodal_embeddings`` is given and non-empty, it is merged into
        positions holding either the image placeholder
        (``config.image_token_index``) or the video placeholder
        (``config.video_token_index``).
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        # An empty collection (e.g. the `()` that get_multimodal_embeddings
        # can return) has nothing to merge, so treat it like None.
        has_mm = (multimodal_embeddings is not None
                  and len(multimodal_embeddings) != 0)
        if has_mm:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_index, self.config.video_token_index])
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[LlavaOnevisionImagePixelInputs] = None,
        video_input: Optional[LlavaOnevisionVideoPixelInputs] = None,
    ) -> torch.Tensor:
        """V0 path: embed text tokens, then merge image and video features.

        Images, if given, are merged first at ``config.image_token_index``
        positions. Videos, if given, are merged next at
        ``config.video_token_index`` positions. Each modality uses its own
        ``merge_multimodal_embeddings`` call.
        """
        embeds = self.get_input_embeddings(input_ids)

        # Each entry: (input or None, encoder, name of the placeholder
        # token attribute on self.config). The attribute is read lazily so
        # it is only accessed when the corresponding input is present.
        for mm_input, encode, token_attr in (
            (image_input, self._process_image_input, "image_token_index"),
            (video_input, self._process_video_pixels, "video_token_index"),
        ):
            if mm_input is None:
                continue
            embeds = merge_multimodal_embeddings(
                input_ids,
                embeds,
                encode(mm_input),
                placeholder_token_id=getattr(self.config, token_attr),
            )

        return embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the forward pass of LLaVA-OneVision.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch. Replaced by ``None`` when multimodal embeddings are
                merged here.
            positions: Forwarded unchanged to the language model.
            intermediate_tensors: Forwarded to the language model. When it
                is not ``None``, ``inputs_embeds`` is discarded.
            inputs_embeds: Precomputed input embeddings. When ``None`` (V0
                path), multimodal inputs are parsed from ``kwargs`` and
                merged here.
            **kwargs: Raw multimodal inputs, e.g. ``pixel_values_videos``
                (pixels of each frame of each input video).

        Returns:
            Whatever ``self.language_model.model`` returns.
        """
        if intermediate_tensors is not None:
            # Intermediate tensors take precedence over any given embeddings.
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: in V1 the model runner always builds inputs_embeds via
            # `get_multimodal_embeddings` and `get_input_embeddings`; this
            # branch exists only for V0 compatibility.
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)
            if image_input is not None or video_input is not None:
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input)
                # The merged embeddings supersede the raw token ids.
                input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        lm = self.language_model
        return lm.compute_logits(hidden_states, sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` into this model via AutoWeightsLoader.

        Returns:
            The set of names returned by ``AutoWeightsLoader.load_weights``
            (presumably the loaded parameter names).
        """
        return AutoWeightsLoader(self).load_weights(weights)