import math
2
from functools import cached_property
3
4
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                    Protocol, Set, Tuple, TypedDict, Union)
5

6
import numpy as np
7
8
import torch
import torch.nn as nn
9
10
from transformers import (BatchFeature, LlavaOnevisionConfig,
                          LlavaOnevisionProcessor)
11
12
13
14
15
from transformers.models.llava_onevision.modeling_llava_onevision import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
16
from vllm.config import VllmConfig
17
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
19
20
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
22
23
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
                                   VideoEmbeddingItems, VideoProcessorItems)
24
25
from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
26
27
28
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

29
from .clip import CLIPVisionModel
30
from .interfaces import SupportsMultiModal, SupportsPP
31
32
33
from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
                         LlavaNextProcessingMixin)
34
from .siglip import SiglipVisionModel
35
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
36
                    maybe_prefix, merge_multimodal_embeddings)
37

38
39
40
# For profile run
_MAX_FRAMES_PER_VIDEO = 16

41
42
43
44
45

class LlavaOnevisionVideoPixelInputs(TypedDict):
    """Video inputs given as raw pixels (`type == "pixel_values_videos"`)."""

    type: Literal["pixel_values_videos"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`

    Note that `num_videos` may be different for each batch, and `num_frames`
    may be different for each video, in which case the data is passed as a
    list instead of a batched tensor.
    """


class LlavaOnevisionImagePixelInputs(TypedDict):
    """Image inputs given as raw pixels (`type == "pixel_values"`).

    `image_sizes` is declared `NotRequired`, so it may be absent.
    """

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaOnevisionImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings (`type == "image_embeds"`).

    `_process_image_input` returns this `data` wrapped in a list without
    running the vision tower or the projector.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# An image input is either raw pixels or precomputed embeddings; the two are
# told apart by their literal `type` field.
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
                                  LlavaOnevisionImageEmbeddingInputs]

# Any multimodal input declared in this module: an image input or video
# pixels.
LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
                                  LlavaOnevisionVideoPixelInputs]


89
90
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
    """Structural config type: `LlavaNextLikeConfig` plus a video token id.

    `video_token_index` is the token id used as the video placeholder when
    building prompt replacements in this module.
    """
    video_token_index: Final[int]
91
92


93
class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
    """Config/processor accessors and token-count helpers for LLaVA-OneVision.

    Extends the LLaVA-NeXT mixin with the OneVision unpadding rule and with
    per-frame / per-video token counting.
    """

    def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
        """Fetch the HF config through `self.ctx`."""
        return self.ctx.get_hf_config(LlavaOnevisionConfig)

    def _get_hf_processor(self):
        """Fetch the HF `LlavaOnevisionProcessor` through `self.ctx`."""
        return self.ctx.get_hf_processor(LlavaOnevisionProcessor)

    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count image features left after removing aspect-ratio padding.

        Args:
            original_height: Height of the original image.
            original_width: Width of the original image.
            npatches: Patch-grid side length of a single tile.
            num_patch_height: Number of tiles stacked vertically.
            num_patch_width: Number of tiles laid out horizontally.

        Returns:
            `(unpadded_features, newline_features)`: the number of grid cells
            kept, and the number of grid rows (one newline feature per row).
        """
        # NOTE: Use float32 to remain consistent with HF output
        current_height_f = np.float32(npatches * num_patch_height)
        current_width_f = np.float32(npatches * num_patch_width)

        original_width_f = np.float32(original_width)
        original_height_f = np.float32(original_height)

        original_aspect_ratio = original_width_f / original_height_f
        current_aspect_ratio = current_width_f / current_height_f

        if original_aspect_ratio > current_aspect_ratio:
            # Image is relatively wider than the grid: padding sits on the
            # top and bottom, so only the height shrinks.
            scale_factor = current_width_f / original_width_f
            new_height = int(original_height_f * scale_factor)
            padding = (current_height_f - new_height) // 2
            current_height_f -= 2 * padding
        else:
            # Otherwise padding sits on the left and right.
            scale_factor = current_height_f / original_height_f
            new_width = int(original_width_f * scale_factor)
            padding = (current_width_f - new_width) // 2
            current_width_f -= 2 * padding

        unpadded_features = int(current_height_f * current_width_f)
        newline_features = int(current_height_f)

        # The 9 matches the "anyres_max_9" default that
        # `_merge_image_patch_embeddings` applies when it downsizes features.
        ratio = math.sqrt(current_height_f * current_width_f /
                          (9 * npatches**2))
        if ratio > 1.1:
            height_factor = int(current_height_f // ratio)
            width_factor = int(current_width_f // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)

    def _get_num_frame_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of tokens one video frame contributes.

        The result is the square of the patch-grid side length divided by
        `spatial_pool_stride` (default 2 when the config lacks it), rounded
        up. `image_width` and `image_height` are accepted but not read here.
        """
        hf_config = self._get_hf_config()
        spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

        vision_encoder_info = self._get_vision_encoder_info()
        patch_grid_length = vision_encoder_info.get_patch_grid_length()
        pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

        return pooled_grid_length * pooled_grid_length

    def _get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        """Return the token count of a whole video.

        This is `num_frames` times the per-frame count, plus one trailing
        newline token.
        """
        num_frame_tokens = self._get_num_frame_tokens(
            image_width=image_width,
            image_height=image_height,
        )

        return num_frame_tokens * num_frames + 1  # Newline token
172

173
174
175
176

class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
                                  BaseLlavaProfilingInfo):
    """Profiling helper: sizes dummy image and video inputs for a profile run."""

    def _get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that yields the most image tokens.

        Raises:
            ValueError: If no pinpoint yields a positive token count.
        """
        hf_config = self._get_hf_config()
        largest_feature_size, largest_feature_pinpoint = 0, None
        for (height, width) in hf_config.image_grid_pinpoints:
            feat_size = self._get_num_image_tokens(image_width=width,
                                                   image_height=height)
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return per-modality item limits; `None` is returned for both."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the maximum token count of one image and of one video."""
        return {
            "image": self._get_max_image_tokens(),
            "video": self._get_max_video_tokens(seq_len),
        }

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose video fits in `max_tokens`.

        Returns 0 when even a single frame does not fit.
        """
        target_width, target_height = self._get_image_size_with_most_features()

        num_frames = 0

        # Grow the frame count one at a time until the budget is exceeded.
        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self._get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def _get_dummy_num_frames(self, seq_len: int) -> int:
        """Return the frame count used for each dummy video.

        The budget left after reserving tokens for the maximum number of
        images is split across the allowed videos, capped at
        `_MAX_FRAMES_PER_VIDEO`, and never allowed to drop below one frame.
        """
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)
        max_videos = mm_config.limit_per_prompt.get("video", 1)

        max_image_tokens = self._get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        # max(max_videos, 1) guards against dividing by a zero video limit.
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)

    def _get_max_video_tokens(self, seq_len: int) -> int:
        """Return the token count of a dummy video at the largest image size."""
        target_width, target_height = self._get_image_size_with_most_features()

        return self._get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self._get_dummy_num_frames(seq_len),
        )

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build the dummy prompt and multimodal data for profiling.

        The prompt is the image token repeated once per image followed by the
        video token repeated once per video.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        processor = self._get_hf_processor()
        image_token = processor.image_token
        video_token = processor.video_token
        target_width, target_height = self._get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=self._get_dummy_num_frames(seq_len),
                num_videos=num_videos,
            )
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images + video_token * num_videos,
            mm_data=mm_data,
        )


class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
                                        LlavaNextMultiModalProcessor):
    """Multimodal processor adding video support to the LLaVA-NeXT one."""

    def _get_profiling_info(self) -> BaseProfilingInfo:
        """Return the profiling helper bound to this processor's context."""
        return LlavaOnevisionProfilingInfo(self.ctx)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map each processor output field to its modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.batched("video"),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling videos one at a time.

        Without videos this defers entirely to the parent implementation.
        With videos, the text and image parts are processed together and
        every video is processed in its own call; the per-video pixel values
        are returned as a list under `pixel_values_videos`.
        """
        # Copy so that popping "videos" does not mutate the caller's mapping.
        mm_data = dict(mm_data)
        videos = mm_data.pop("videos", [])
        assert isinstance(videos, list)

        if not videos:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
            )

        processor = self._get_hf_processor()
        video_token = processor.video_token

        # LLaVA-OneVision processor doesn't support multiple videos
        # with different sizes when converting back to tensors
        text_image_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values_videos = []
        for video in videos:
            # Pair each video with exactly one video token. Only the pixel
            # values are kept from this call, so the full prompt is not
            # needed, and no extra "prompt" entry is placed in mm_data.
            item_outputs = super()._call_hf_processor(
                prompt=video_token,
                mm_data={"videos": video},
                mm_kwargs=mm_kwargs,
            )

            pixel_values_videos.append(
                item_outputs.pop("pixel_values_videos")[0])

        combined_outputs = dict(
            **text_image_outputs,
            pixel_values_videos=pixel_values_videos,
        )
        return BatchFeature(combined_outputs)

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return the parent's image replacements plus one for videos.

        Each video placeholder token is replaced by the video token repeated
        once per feature of that video.
        """
        image_repls = super()._get_prompt_replacements(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        hf_config = self._get_hf_config()
        video_token_id = hf_config.video_token_index

        def get_video_replacement(item_idx: int):
            # Looked up inside the callback, so it only runs when a video
            # replacement is actually requested.
            videos = mm_items.get_items(
                "video", (VideoEmbeddingItems, VideoProcessorItems))

            if isinstance(videos, VideoEmbeddingItems):
                num_video_tokens = videos.get_feature_size(item_idx)
            else:
                image_size = videos.get_frame_size(item_idx)
                num_video_tokens = self._get_num_video_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    num_frames=videos.get_num_frames(item_idx),
                )

            return [video_token_id] * num_video_tokens

        return image_repls + [
            PromptReplacement(
                modality="video",
                target=[video_token_id],
                replacement=get_video_replacement,
            ),
        ]

381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401

class LlavaOnevisionMultiModalProjector(nn.Module):
    """Two-layer projector from vision hidden size to text hidden size.

    Applies `linear_1`, then the activation named by
    `config.projector_hidden_act`, then `linear_2`.
    """

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()

        vision_width = config.vision_config.hidden_size
        text_width = config.text_config.hidden_size

        self.linear_1 = nn.Linear(vision_width, text_width, bias=True)
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(text_width, text_width, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` into the text hidden size."""
        return self.linear_2(self.act(self.linear_1(image_features)))


402
@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
403
404
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                             SupportsPP):
405

406
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector, language model and newline.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the multimodal config.
            prefix: Prefix passed through `maybe_prefix` when naming the
                `vision_tower` and `language_model` submodules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Created uninitialized via torch.empty; this method does not fill it.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the sampler to use, computed once per instance.

        Prefers the language model's own `sampler` attribute when it has
        one, and otherwise falls back to `get_sampler()`.
        """
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()
    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_image_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
        """Turn image keyword arguments into a typed image input.

        Reads `pixel_values`, `image_sizes` and `image_embeds` from `kwargs`.
        Pixel values take precedence over embeddings; `None` is returned when
        neither is given. Raises `ValueError` on an unexpected input type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        has_pixels = pixel_values is not None
        has_embeds = image_embeds is not None

        if not (has_pixels or has_embeds):
            return None

        if has_pixels:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            checked_pixels = self._validate_image_pixel_values(
                flatten_bn(pixel_values))
            checked_sizes = self._validate_image_sizes(
                flatten_bn(image_sizes, concat=True))

            return LlavaOnevisionImagePixelInputs(
                type="pixel_values",
                data=checked_pixels,
                image_sizes=checked_sizes,
            )

        if has_embeds:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaOnevisionImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[2:])

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_video_input(
            self,
            **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
        """Turn the `pixel_values_videos` keyword into a typed video input.

        Accepts either a list of tensors (videos of different shapes) or a
        single tensor (videos of the same shape). Returns `None` when the
        keyword is absent; raises `ValueError` for any other type. Only the
        container type is checked here.
        """
        video_pixels = kwargs.pop("pixel_values_videos", None)
        if video_pixels is None:
            return None

        is_tensor_list = is_list_of(video_pixels, torch.Tensor)
        if not is_tensor_list and not isinstance(video_pixels, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(video_pixels)}")

        return LlavaOnevisionVideoPixelInputs(
            type="pixel_values_videos",
            data=video_pixels,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

567
568
569
570
571
572
573
574
575
        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key == "pixel_values" and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key == "pixel_values_videos" and "videos" not in modalities:  # noqa E501
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746

        return modalities

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured feature selection."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)
        select_strategy = self.config.vision_feature_select_strategy

        return self._select_image_features(tower_output,
                                           strategy=select_strategy)

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self,
                                      image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor,
                                      *,
                                      image_newline=None,
                                      vision_aspect_ratio="anyres_max_9",
                                      strategy: str) -> torch.Tensor:
        """Merge the base and high-resolution patch embeddings of one image.

        Args:
            image_size: Original image size, unpacked as `(height, width)`.
            patch_embeddings: Embeddings of every patch of the image; index 0
                is used as the base patch and the rest as the
                high-resolution patches.
            image_newline: Optional embedding appended as an extra column of
                the unpadded feature grid in the multi-patch "unpad" path.
            vision_aspect_ratio: String of the form `"anyres_max_<N>"`; `N`
                sets the size budget above which the grid is downscaled.
            strategy: `"flat"`, or a name starting with `"spatial"`. If the
                name contains `"unpad"`, padding is removed.

        Returns:
            A tensor with the merged embeddings concatenated along dim 0.

        Raises:
            ValueError: If the base patch's token count does not match the
                patch grid from the vision config, or if `strategy` is
                unknown.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            # Side length of one patch's token grid.
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # Rearrange into a single (channels, rows, cols) feature
                    # map so that padding can be cropped spatially.
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    max_num_patches = int(
                        vision_aspect_ratio.removeprefix("anyres_max_"))
                    channels, curr_height, curr_width = other_patch_embeds.shape
                    # Downscale bilinearly once the map exceeds the budget of
                    # `max_num_patches` base grids by more than 10% per side.
                    ratio = math.sqrt(curr_height * curr_width /
                                      (max_num_patches * height**2))
                    if ratio > 1.1:
                        other_patch_embeds = other_patch_embeds[None]
                        other_patch_embeds = nn.functional.interpolate(
                            other_patch_embeds, [
                                int(curr_height // ratio),
                                int(curr_width // ratio)
                            ],
                            mode="bilinear")[0]
                    # Append the newline embedding as one extra column, i.e.
                    # one newline token at the end of every row.
                    if image_newline is not None:
                        other_patch_embeds = torch.cat(
                            (
                                other_patch_embeds,
                                image_newline[:, None, None] \
                                .expand(*other_patch_embeds.shape[:-1], 1) \
                                .to(other_patch_embeds.device),
                            ),
                        dim=-1)
                    # Back to a (tokens, channels) layout.
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                # NOTE(review): this single-patch branch appends
                # `self.image_newline` whenever "unpad" is in the strategy,
                # regardless of the `image_newline` argument.
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaOnevisionImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode and project image pixels into patch embeddings.

        A batched tensor input yields a batched tensor; a list input yields
        a list with one projected tensor per input entry.
        """
        assert self.vision_tower is not None

        pixels = inputs["data"]

        if not isinstance(pixels, torch.Tensor):
            # Ragged input: encode all patches in one pass, then split the
            # features back per entry before projecting.
            patch_counts = [entry.shape[0] for entry in pixels]
            all_features = self._image_pixels_to_features(
                self.vision_tower, torch.cat(pixels))
            per_entry = torch.split(all_features, patch_counts)

            return [self.multi_modal_projector(chunk) for chunk in per_entry]

        batch, patches, channels, rows, cols = pixels.shape
        flat_pixels = pixels.view(batch * patches, channels, rows, cols)
        flat_features = self._image_pixels_to_features(self.vision_tower,
                                                       flat_pixels)
        flat_embeds = self.multi_modal_projector(flat_features)

        return flat_embeds.view(batch, patches, *flat_embeds.shape[1:])

    def _process_image_input(
        self,
        image_input: LlavaOnevisionImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Convert an image input into per-image merged embeddings.

        Precomputed embeddings are returned wrapped in a list. Pixel inputs
        are encoded and then merged per image with the `"spatial_unpad"`
        strategy; when `image_sizes` is missing, each image is assumed to
        have the vision config's square `image_size`.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            side = self.config.vision_config.image_size
            num_entries = len(image_input["data"])
            image_sizes = torch.as_tensor([[side, side]] * num_entries)

        merged = []
        for idx, patch_features_batch in enumerate(patch_embeddings):
            merged.append(
                self._merge_image_patch_embeddings(
                    image_sizes[idx],
                    patch_features_batch,
                    image_newline=self.image_newline,
                    strategy="spatial_unpad"))

        return merged

    def _add_image_newline(
        self,
        video_features: torch.Tensor,
        videos: int = 1,
        frames: int = 1,
        strategy: str = "one_token",
    ) -> torch.Tensor:
        if strategy == "one_token":
            video_features = video_features.reshape(
                videos, frames * video_features.shape[1], -1)
            image_newline = self.image_newline[None, None, :].repeat(
                videos, 1, 1).to(video_features.device)
            video_features = torch.cat((video_features, image_newline), dim=1)
            return video_features
        raise ValueError(f"Unexpected video newline strategy: {strategy}")

763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode video frame pixels into pooled, projected features.

        The frames go through four steps in order: the vision tower, feature
        selection according to ``config.vision_feature_select_strategy``, the
        multimodal projector, and finally ``apply_pooling``.

        Args:
            vision_tower: Callable vision encoder applied to ``pixel_values``.
            pixel_values: Frame pixels handed to the vision tower as-is.

        Returns:
            The output of ``apply_pooling`` on the projected features.
        """
        # NOTE: the vision feature layer is not selected here because the
        # vision tower already takes care of that internally.
        tower_output = vision_tower(pixel_values)
        selected = self._select_image_features(
            tower_output,
            strategy=self.config.vision_feature_select_strategy,
        )
        projected = self.multi_modal_projector(selected)
        return self.apply_pooling(projected)

    def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
        """Encode video pixel inputs into embeddings with a trailing newline.

        Two layouts of ``inputs["data"]`` are accepted:

        * a single tensor unpacked as ``(b, num_videos, frames, c, h, w)``;
          all frames are encoded in one pass and a single tensor with
          ``b * num_videos`` entries along dimension 0 is returned;
        * a list of tensors, each unpacked as
          ``(num_videos, frames, c, h, w)``; each element is encoded
          separately and a list of tensors is returned.

        In both cases the frames are flattened into the batch dimension,
        passed through ``_video_pixels_to_features`` and then through
        ``_add_image_newline`` with the ``"one_token"`` strategy.

        Raises:
            ValueError: If ``inputs["data"]`` is neither a tensor nor a list
                of tensors.
        """
        assert self.vision_tower is not None

        video_pixels = inputs["data"]

        if isinstance(video_pixels, torch.Tensor):
            b, num_videos, frames, c, h, w = video_pixels.shape
            # Fold batch, video and frame axes together so the vision tower
            # sees a plain batch of images.
            pixel_values = video_pixels.view(b * num_videos * frames, c, h, w)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, pixel_values)
            stacked_embeddings = self._add_image_newline(stacked_embeddings,
                                                         videos=b * num_videos,
                                                         frames=frames,
                                                         strategy="one_token")
            return stacked_embeddings
        elif is_list_of(video_pixels, torch.Tensor):
            stacked_embeddings = []
            for video_pixel in video_pixels:
                num_videos, frames, c, h, w = video_pixel.shape
                pixel_values = video_pixel.view(num_videos * frames, c, h, w)
                embeddings = self._video_pixels_to_features(
                    self.vision_tower, pixel_values)
                embeddings = self._add_image_newline(embeddings,
                                                     videos=num_videos,
                                                     frames=frames,
                                                     strategy="one_token")
                stacked_embeddings.append(embeddings)
            return stacked_embeddings
        else:
            raise ValueError(
                f"Unsupported type of video input {type(video_pixels)}")

    def apply_pooling(self, image_features, stride=2):
        """Shrink each frame's square patch grid with bilinear interpolation.

        The token axis of ``image_features`` is laid out as a square grid
        whose side is ``image_size // patch_size`` from the vision config,
        resized to ``ceil(side / stride)`` per axis, and flattened back into
        a token axis.

        Args:
            image_features: 3-D tensor unpacked as
                ``(batch_frames, tokens, dim)``.
            stride: Downsampling factor applied to both grid axes.

        Returns:
            Tensor viewed as ``(batch_frames, -1, dim)`` holding the resized
            grid tokens.
        """
        vision_cfg = self.config.vision_config
        grid_side = vision_cfg.image_size // vision_cfg.patch_size
        num_frames, _, hidden = image_features.shape

        # Lay tokens out on the grid and move channels first, which is the
        # layout `interpolate` expects.
        grid = image_features.view(num_frames, grid_side, grid_side, -1)
        grid = grid.permute(0, 3, 1, 2)

        # TODO support other pooling types config
        rows, cols = grid.shape[2:]
        target_size = [math.ceil(rows / stride), math.ceil(cols / stride)]
        pooled = nn.functional.interpolate(grid,
                                           size=target_size,
                                           mode='bilinear')

        # Back to channels-last, then collapse the grid into a token axis.
        pooled = pooled.permute(0, 2, 3, 1)
        return pooled.view(num_frames, -1, hidden)

829
830
831
832
833
834
    def get_multimodal_embeddings(
            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
        """Compute embeddings for every multimodal item found in ``kwargs``.

        ``kwargs`` is forwarded to ``_parse_and_validate_multimodal_inputs``.
        Images are embedded with ``_process_image_input`` and videos with
        ``_process_video_pixels``; the results are concatenated in the order
        in which the modalities appear in the parsed mapping.

        Returns:
            ``None`` when no modality was parsed; otherwise a flat tuple of
            embeddings, one entry per multimodal item.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(vision_embeddings)
            elif modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_pixels(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
                                                   str]]] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings, if any.

        Args:
            input_ids: Token ids passed to the language model's embedding
                lookup.
            multimodal_embeddings: Embeddings to merge at the positions of
                the image and video placeholder tokens; when ``None`` the
                text embeddings are returned unchanged.

        Returns:
            The input embeddings, with multimodal embeddings merged in when
            they were provided.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # Both placeholder ids are passed so image and video items are
            # merged in a single call.
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_index, self.config.video_token_index])
        return inputs_embeds

866
867
868
869
870
871
872
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-Onevision.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded unchanged to the language model.
            kv_caches: Forwarded unchanged to the language model.
            attn_metadata: Forwarded unchanged to the language model.
            intermediate_tensors: When given, ``inputs_embeds`` is discarded
                and these tensors are forwarded to the language model.
            inputs_embeds: Precomputed input embeddings. When ``None`` (and
                no ``intermediate_tensors`` are given) they are built here
                from ``input_ids`` and the multimodal data in ``kwargs``.
            **kwargs: Multimodal inputs, passed on to
                ``get_multimodal_embeddings``.

        Returns:
            The output of ``self.language_model.model``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            # The embeddings now carry the token information, so the ids are
            # not passed on to the language model.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the wrapped language model.

        Both arguments are forwarded positionally and the language model's
        result is returned unchanged.
        """
        logits = self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the wrapped language model.

        Both arguments are forwarded positionally and the language model's
        result is returned unchanged.
        """
        sampler_output = self.language_model.sample(logits, sampling_metadata)
        return sampler_output

917
918
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this module.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of names reported by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)