hyperclovax_vision.py 40.1 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# copied from : https://github.com/huggingface/transformers
import ast
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
8
9
from itertools import accumulate
from typing import Annotated, Any, Literal, Optional, Union
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
14
from einops import rearrange
15
16
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
17
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
18
19
20
from transformers.modeling_utils import no_init_weights

from vllm.config import VllmConfig
21
from vllm.config.multimodal import BaseDummyOptions
22
23
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.cache import BaseMultiModalProcessorCache
25
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
26
                                    MultiModalKwargsItems)
27
28
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
29
30
31
                                        BaseProcessingInfo,
                                        InputProcessingContext,
                                        PromptReplacement, PromptUpdate)
32
33
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
34
from vllm.utils.tensor_schema import TensorSchema, TensorShape
35
36
37
38

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
39
40
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix)
41
42
43
44
45
46
47
from .vision import get_vision_encoder_info

# Special-token strings used by this module.
# NOTE(review): EOT is not referenced in the visible part of this file;
# presumably the end-of-turn marker of the chat template - confirm.
EOT = "<|endofturn|>"
# Prompt placeholder strings for an image item and a video item. They are
# returned by `get_placeholder_str` and repeated by the dummy-text builder.
IMAGE_TOKEN: str = "<|dummy3|>"
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Based on combine_frames_into_images in
# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
def get_num_combined_frames(
        num_frames: int,
        max_grid_shape: tuple[int, int] = (3, 3),
) -> int:
    """Return how many canvases are needed to hold ``num_frames`` frames.

    A canvas holds at most ``max_grid_shape[0] * max_grid_shape[1]`` frames;
    a partially filled last canvas still counts as one canvas.

    Args:
        num_frames: Number of frames to place.
        max_grid_shape: Grid dimensions of a single canvas.

    Returns:
        The number of canvases (ceiling of frames / canvas capacity).
    """
    capacity = max_grid_shape[0] * max_grid_shape[1]
    full_canvases, remainder = divmod(num_frames, capacity)

    # Any leftover frames spill into one extra canvas.
    if remainder > 0:
        return full_canvases + 1
    return full_canvases


63
class HCXVisionImagePixelInputs(TensorSchema):
    """
    Schema of the image tensors passed to the model.

    Dimensions:
        - n: Number of images
        - g: Number of grids (declared as a dynamic dimension)
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    type: Literal["pixel_values"] = "pixel_values"

    # One entry per image.
    pixel_values_images: Annotated[
        list[torch.Tensor],
        TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"}),
    ]

    # Two values per image.
    image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]


# Alias used in the signatures of the model class below.
HCXVisionImageInputs = HCXVisionImagePixelInputs


class HCXVisionVideoPixelInputs(TensorSchema):
    """
    Schema of the video tensors passed to the model.

    Dimensions:
        - n: Number of videos
        - f: Number of frames (declared as a dynamic dimension)
        - g: Number of grids (declared as a dynamic dimension)
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # Outer list: one entry per video; inner list: one tensor per frame entry.
    pixel_values_videos: Annotated[
        list[list[torch.Tensor]],
        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}),
    ]


# Alias used in the signatures of the model class below.
HCXVisionVideoInputs = HCXVisionVideoPixelInputs
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157


class HCXVisionProcessingInfo(BaseProcessingInfo):
    """Processing-time helpers (limits, token counts, sizes) for HCX vision."""

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; both entries are ``None``."""
        return dict(image=None, video=None)

    def get_num_image_tokens(
        self,
        *,
        vision_query_length: Union[int, list[int]],
    ) -> int:
        """Return the image token count.

        An ``int`` is returned unchanged; any other value is summed.
        """
        if not isinstance(vision_query_length, int):
            return sum(vision_query_length)
        return vision_query_length

    def get_num_video_tokens(
        self,
        *,
        vision_query_length: Union[int, list[int]],
    ) -> int:
        """Return the video token count.

        An ``int`` is returned unchanged; any other value is summed.
        """
        if not isinstance(vision_query_length, int):
            return sum(vision_query_length)
        return vision_query_length

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square ``ImageSize`` using the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the largest supported image.

        NOTE(review): ``get_num_image_tokens`` only accepts the keyword
        ``vision_query_length``; calling it with ``image_width`` /
        ``image_height`` raises ``TypeError``. The call is kept as written
        because the intended query length cannot be derived here.
        """
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
        )


class HCXVisionDummyInputsBuilder(
        BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
    """Builds placeholder text and dummy image/video data."""

    def get_dummy_text(
        self,
        mm_counts: Mapping[str, int],
    ) -> str:
        """Return the image token and video token repeated per requested item.

        All image tokens come first, followed by all video tokens.
        """
        image_part = IMAGE_TOKEN * mm_counts.get("image", 0)
        video_part = VIDEO_TOKEN * mm_counts.get("video", 0)
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos for the requested counts.

        Images use the size reported by
        ``get_image_size_with_most_features``; videos use that size minus one
        pixel in each dimension and 32 frames. Per-modality overrides from
        ``mm_options`` are forwarded when given.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frames_per_video = 32

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = None
            video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width - 1,
            height=height - 1,
            num_frames=frames_per_video,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}


class HCXVisionMultiModalProcessor(
        BaseMultiModalProcessor[HCXVisionProcessingInfo]):
    """Multimodal processor that runs the HF processor for text and media."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on the prompt and, separately, on the media.

        The text and the multimodal data are processed in two independent
        calls and the multimodal outputs are merged into the text outputs.
        ``tok_kwargs`` is accepted but not used here.
        """
        # Video arrays that are not uint8 are cast and written back into
        # ``mm_data["videos"]``.
        for idx, clip in enumerate(mm_data.get("videos", [])):
            if clip.dtype == np.uint8:
                continue
            mm_data["videos"][idx] = clip.astype(np.uint8)

        # First pass: text only.
        text_outputs = self.info.ctx.call_hf_processor(
            hf_processor=self.info.get_hf_processor(**mm_kwargs),
            data={
                "text": prompt,
                "images": None,
                "videos": None,
            },
        )

        if len(mm_data) == 0:
            return text_outputs

        images = mm_data.get("images")
        videos = mm_data.get("videos")

        # Second pass: multimodal data only, wrapped as a single batch item.
        mm_outputs = self.info.ctx.call_hf_processor(
            hf_processor=self.info.get_hf_processor(**mm_kwargs),
            data={
                "text": None,
                "images": [images] if images is not None else None,
                "videos": [videos] if videos is not None else None,
            },
        )

        # Undo the single-item batching for every non-empty list output.
        for key, value in mm_outputs.items():
            if not isinstance(value, list) or len(value) == 0:
                continue
            assert len(value) == 1
            mm_outputs[key] = value[0]

        if images:
            for key in ("image_sizes_images", "vision_query_lengths_images"):
                mm_outputs[key] = torch.tensor(mm_outputs[key])

        if videos:
            # Cumulative canvas counts give each video's slice boundaries in
            # the flat per-canvas outputs.
            boundaries = [0]
            boundaries.extend(
                accumulate(
                    get_num_combined_frames(len(video)) for video in videos))
            spans = list(zip(boundaries, boundaries[1:]))

            flat_pixels = mm_outputs["pixel_values_videos"]
            mm_outputs["pixel_values_videos"] = [
                flat_pixels[start:stop] for start, stop in spans
            ]

            flat_lengths = mm_outputs["vision_query_lengths_videos"]
            mm_outputs["vision_query_lengths_videos"] = [
                torch.tensor(flat_lengths[start:stop])
                for start, stop in spans
            ]

        text_outputs.update(mm_outputs)
        return text_outputs

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always report that the HF processor does not apply prompt updates."""
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build one prompt replacement per modality (image, then video).

        Each placeholder token id is replaced by that same id repeated once
        per vision query of the corresponding item.
        """
        hf_config = self.info.get_hf_config()
        token_ids = {
            "image": hf_config.image_token_id,
            "video": hf_config.video_token_id,
        }

        def get_replacement_hyperclovax(
            item_idx: int,
            modality: str,
            out_mm_kwargs: MultiModalKwargsItems,
        ):
            item = out_mm_kwargs[modality][item_idx]

            if modality == "image":
                query_lengths = (
                    item["vision_query_lengths_images"].data.tolist())
                count = self.info.get_num_image_tokens(
                    vision_query_length=query_lengths)
            elif modality == "video":
                query_lengths = (
                    item["vision_query_lengths_videos"].data.tolist())
                count = self.info.get_num_video_tokens(
                    vision_query_length=query_lengths)
            else:
                raise NotImplementedError(modality)

            return [token_ids[modality]] * count

        updates = []
        for modality in ("image", "video"):
            replacement = partial(
                get_replacement_hyperclovax,
                modality=modality,
                out_mm_kwargs=out_mm_kwargs,
            )
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[
                        token_ids[modality],
                    ],
                    replacement=replacement,
                ))
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every multimodal field as batched over its modality."""
        image_fields = (
            "pixel_values_images",
            "image_sizes_images",
            "vision_query_lengths_images",
        )
        video_fields = (
            "pixel_values_videos",
            "vision_query_lengths_videos",
        )

        fields = {
            name: MultiModalFieldConfig.batched("image")
            for name in image_fields
        }
        fields.update({
            name: MultiModalFieldConfig.batched("video")
            for name in video_fields
        })
        return fields


def _build_hcxvision_hf_info(
        ctx: InputProcessingContext) -> HCXVisionProcessingInfo:
    """Create the processing-info object for the given processing context."""
    info = HCXVisionProcessingInfo(ctx)
    return info


def _build_hcxvision_hf_processor(
    info: HCXVisionProcessingInfo,
    dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
    *,
    cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor for an ``HCXVisionProcessingInfo``.

    Raises:
        NotImplementedError: If ``info`` is not an
            ``HCXVisionProcessingInfo``; the offending type is the argument.
    """
    if not isinstance(info, HCXVisionProcessingInfo):
        raise NotImplementedError(type(info))

    return HCXVisionMultiModalProcessor(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )


def init_vision_tower_for_hcxvision(
    vision_config,
    quant_config: Optional[QuantizationConfig],
    *,
    use_nth_layer: Optional[int] = None,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel]:
    """Instantiate the CLIP or SigLIP vision tower for this model.

    Args:
        vision_config: A ``CLIPVisionConfig`` or ``SiglipVisionConfig``.
        quant_config: Quantization config forwarded to the tower.
        use_nth_layer: When an int, selects how many hidden layers to keep:
            ``n >= 0`` keeps ``n + 1`` layers; a negative ``n`` keeps
            ``num_hidden_layers + n + 1``. Any non-int keeps all layers.
        require_post_norm: Forwarded to the tower.
        prefix: Forwarded to the tower.

    Raises:
        NotImplementedError: For any other vision config type.
    """
    layer_count = vision_config.num_hidden_layers
    if isinstance(use_nth_layer, int):
        if use_nth_layer >= 0:
            layer_count = use_nth_layer + 1
        else:
            # Negative values index from the end.
            layer_count = layer_count + use_nth_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=layer_count,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


class HCXVisionMlp(nn.Module):
    """Two-layer MLP projector: ``fc1`` -> activation -> ``fc2``.

    Args:
        mm_projector_type: ``"mlp"`` uses ``hidden_features`` as the inner
            width; ``"inverted_mlp"`` uses twice that width.
        in_features: Size of the last input dimension.
        hidden_features: Inner width; falls back to ``in_features`` when
            falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.

    Attributes:
        dtype: Parameter dtype captured at construction time.

    Raises:
        NotImplementedError: For any other ``mm_projector_type``.
    """

    def __init__(
        self,
        mm_projector_type,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mm_projector_type = mm_projector_type

        if mm_projector_type == "mlp":
            inner_features = hidden_features
        elif mm_projector_type == "inverted_mlp":
            inner_features = 2 * hidden_features
        else:
            raise NotImplementedError(
                f"{mm_projector_type} is not implemented")

        self.fc1 = nn.Linear(in_features, inner_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(inner_features, out_features)

        # Expose the parameter dtype the same way HCXVisionCAbstractor does:
        # the model's forward_images / forward_videos cast features with
        # `self.mm_projector.dtype`, which a plain nn.Module does not provide.
        self.dtype = next(self.parameters()).dtype

    def forward(self, x):
        """Apply ``fc1``, the activation, then ``fc2`` to ``x``."""
        return self.fc2(self.act(self.fc1(x)))


class HCXVisionCAbstractor(nn.Module):
    """
    Convolutional abstractor used as a multimodal projector.

    This module is based on C-Abstractor, whose license is under apache-2.0.
    You can check the original code at
    https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
    and we made necessary modifications.

    The network is ``s1 -> adaptive average pool -> s2`` followed by an MLP
    readout applied to the flattened spatial positions.
    """

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb: bool = True,
        prenorm: bool = False,
    ):
        """Create the optional positional embedding / pre-norm and the net.

        Args:
            num_queries: Number of output positions; must be a perfect
                square (checked in ``build_net``).
            num_input_tokens: Length of the positional embedding.
            encoder_hidden_size: Channel size of the incoming features.
            hidden_size: Channel size inside the convolutional stages.
            output_hidden_size: Channel size produced by the readout MLP.
            pos_emb: Whether to add a learned positional embedding.
            prenorm: Whether to apply a LayerNorm before everything else.
        """
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Learned positional embedding, drawn from N(0, 0.02).
        self.pos_emb = None
        if pos_emb:
            self.pos_emb = torch.nn.Parameter(
                torch.zeros(1, num_input_tokens, encoder_hidden_size))
            self.pos_emb.data.normal_(mean=0.0, std=0.02)

        # Optional normalization applied to the raw input.
        self.prenorm = LayerNorm(encoder_hidden_size) if prenorm else None

        self.build_net(num_queries, encoder_hidden_size, hidden_size,
                       output_hidden_size)
        # Parameter dtype captured at construction time.
        self.dtype = next(self.parameters()).dtype

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[list[list[int]]] = None,
        num_grids: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Apply pre-norm and positional embedding, then run ``_forward``.

        NOTE(review): when ``num_queries_vis_abstractors`` is given the
        result is a list of tensors (see ``_forward_adaptive_num_query``),
        not a single tensor as annotated.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb

        # (B, L, output_hidden_size) in the non-adaptive case.
        return self._forward(
            x,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_grids=num_grids,
        )

    def _forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[list[list[int]]] = None,
        num_grids: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Reshape ``[B, L, dim]`` tokens to a square map and project them."""
        _, seq_len, _ = x.shape
        side = int(seq_len**0.5)
        feature_map = rearrange(x, "b (h w) d -> b d h w", h=side, w=side)

        if num_queries_vis_abstractors is None:
            pooled = self.net(feature_map)
            tokens = rearrange(pooled, "b d h w -> b (h w) d")
            return self.readout(tokens)

        assert num_grids is not None
        return self._forward_adaptive_num_query(
            feature_map, num_queries_vis_abstractors, num_grids)

    def _forward_adaptive_num_query(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[list[list[int]]] = None,
        num_grids: Optional[list[int]] = None,
    ) -> list[torch.Tensor]:
        """Project consecutive batch slices with per-slice pooling sizes.

        Slice ``i`` covers ``x[num_grids[i]:num_grids[i + 1]]`` and is pooled
        to ``sqrt(num_queries_vis_abstractors[i])`` positions per side.

        NOTE(review): each element of ``num_queries_vis_abstractors`` is used
        as a number (``** 0.5``), so the ``list[list[int]]`` annotation looks
        inaccurate.
        """
        # self.net consists of 3 layers (s1, sampler, s2); the built-in
        # sampler is bypassed here in favour of a per-slice one.
        assert len(self.net) == 3
        first_stage = self.net[0]
        last_stage = self.net[2]

        features = first_stage(x)
        outputs = []
        for i, num_queries in enumerate(num_queries_vis_abstractors):
            side = int(num_queries**0.5)
            pool = nn.AdaptiveAvgPool2d((side, side))
            chunk = features[num_grids[i]:num_grids[i + 1], :]

            projected = last_stage(pool(chunk))
            projected = rearrange(projected, "b d h w -> b (h w) d")
            outputs.append(self.readout(projected))
        return outputs

    def build_net(
        self,
        n_queries: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        depth: int = 3,
        mlp_depth: int = 2,
    ):
        """Create ``self.net`` (s1, sampler, s2) and ``self.readout``."""
        assert (n_queries**0.5).is_integer(
        ), f"n_queries must be square number. n_queries: {n_queries}"
        side = int(n_queries**0.5)

        # RegBlock = ResBlock + SE
        make_stage = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        stages = (
            make_stage(depth, encoder_hidden_size, hidden_size),
            nn.AdaptiveAvgPool2d((side, side)),
            make_stage(depth, hidden_size, hidden_size),
        )
        self.net = nn.Sequential(*stages)
        self.readout = self.build_mlp(mlp_depth, hidden_size,
                                      output_hidden_size)

    def build_mlp(
        self,
        depth: int,
        hidden_size: int,
        output_hidden_size: int,
    ):
        """Return ``depth`` Linear layers separated by SiLU activations."""
        layers = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(depth - 1):
            layers.extend((
                nn.SiLU(),
                nn.Linear(output_hidden_size, output_hidden_size),
            ))
        return nn.Sequential(*layers)


@MULTIMODAL_REGISTRY.register_processor(
    _build_hcxvision_hf_processor,
    info=_build_hcxvision_hf_info,
    dummy_inputs=HCXVisionDummyInputsBuilder)
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
580
    merge_by_field_config = True
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655

    # Fused projection name -> names of the separate projections it packs.
    # NOTE(review): not read anywhere in the visible part of this class;
    # presumably consumed by the surrounding vLLM framework - confirm.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs: Optional[Any],
    ) -> None:
        """Build the vision tower, the multimodal projector and the LM.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the model dtype.
            prefix: Name prefix passed (via ``maybe_prefix``) to the vision
                tower and the language model.
            **kwargs: Accepted but not used by the visible code.

        Note:
            The HF ``text_config`` and ``vision_config`` objects, as well as
            ``config.possible_resolutions``, are modified in place.
        """
        super().__init__()

        # init configs
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        # text_config: force SDPA attention for the listed model types and
        # neutralize logits scaling for every non-hyperclovax text model.
        text_config = config.text_config
        if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
            text_config._attn_implementation = "sdpa"
        if text_config.model_type != "hyperclovax":
            text_config.logits_scaling = 1.0
        # vision_config: clear auto_map and mirror the any-resolution
        # settings from the top-level config.
        vision_config = config.vision_config
        vision_config.auto_map = {}
        vision_config.anyres = config.anyres
        vision_config.max_num_grids = config.max_num_grids
        self.dtype = vllm_config.model_config.dtype

        ## possible_resolution should be matched with preprocessor_config.json
        config.possible_resolutions = self._init_possible_resolutions(
            config, vision_config)

        # init models & parameters
        with no_init_weights():  # weight will be loaded in from_pretrained
            self.vision_model = init_vision_tower_for_hcxvision(
                vision_config,
                quant_config,
                # Falls back to -1 when the config has no use_nth_layer.
                use_nth_layer=getattr(config, "use_nth_layer", -1),
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
        self.mm_projector = self._init_mm_projector(config, text_config,
                                                    vision_config)

        # Prefer the padded vocabulary size when the text config defines one.
        self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size",
                                          text_config.vocab_size)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # The newline parameter only exists when anyres is enabled; it is
        # created uninitialized (torch.empty).
        if config.anyres:
            self.image_newline = nn.Parameter(
                torch.empty(text_config.hidden_size, dtype=self.dtype))

        self.config = config
        self.vision_config = vision_config
        self.text_config = text_config

        # use_sum_loss = bool(kwargs.pop("use_sum_loss", False))
        # self.reduction = self._init_reduction_type(use_sum_loss)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder string for ``modality``.

        Modalities starting with ``"image"`` map to ``IMAGE_TOKEN`` and those
        starting with ``"video"`` map to ``VIDEO_TOKEN``; ``i`` is unused.

        Raises:
            ValueError: For any other modality.
        """
        for modality_prefix, token in (
            ("image", IMAGE_TOKEN),
            ("video", VIDEO_TOKEN),
        ):
            if modality.startswith(modality_prefix):
                return token

        raise ValueError("Only image or video modality is supported")

656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Optional[HCXVisionImageInputs]:
        """Wrap the image keyword arguments in ``HCXVisionImagePixelInputs``.

        Returns ``None`` when ``pixel_values_images`` is missing or ``None``;
        otherwise ``image_sizes_images`` is required.
        """
        pixel_values = kwargs.pop("pixel_values_images", None)
        if pixel_values is not None:
            image_sizes = kwargs.pop("image_sizes_images")
            return HCXVisionImagePixelInputs(
                pixel_values_images=pixel_values,
                image_sizes_images=image_sizes,
            )

        return None

    def _parse_and_validate_video_input(
        self,
        **kwargs: object,
    ) -> Optional[HCXVisionVideoInputs]:
        """Wrap the video keyword arguments in ``HCXVisionVideoPixelInputs``.

        Returns ``None`` when ``pixel_values_videos`` is missing or ``None``.
        """
        pixel_values = kwargs.pop("pixel_values_videos", None)
        if pixel_values is not None:
            return HCXVisionVideoPixelInputs(pixel_values_videos=pixel_values)

        return None

    def _process_image_input(
        self,
        image_input: HCXVisionImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Run ``forward_images`` on the fields of a parsed image input."""
        pixel_values = image_input["pixel_values_images"]
        image_sizes = image_input["image_sizes_images"]
        return self.forward_images(
            pixel_values_images=pixel_values,
            image_sizes_images=image_sizes,
        )

    def _process_video_input(
        self,
        video_input: HCXVisionVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Run ``forward_videos`` on the pixel values of a parsed video input."""
        pixel_values = video_input["pixel_values_videos"]
        return self.forward_videos(pixel_values_videos=pixel_values)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse the image / video inputs present in ``kwargs``.

        Returns a dict with up to two entries, ``"images"`` and ``"videos"``,
        inserted in the order their pixel-value keys appear in ``kwargs``.
        """
        parsers = {
            "pixel_values_images":
            ("images", self._parse_and_validate_image_input),
            "pixel_values_videos":
            ("videos", self._parse_and_validate_video_input),
        }

        modalities = {}
        # Walking kwargs in order preserves the order of the modalities when
        # more than one of them is present.
        for input_key in kwargs:
            entry = parsers.get(input_key)
            if entry is None:
                continue
            name, parse = entry
            if name not in modalities:
                modalities[name] = parse(**kwargs)

        return modalities

717
718
719
720
721
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        language_model = self.language_model
        return language_model

    def get_multimodal_embeddings(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Compute the embeddings of all image / video items in ``kwargs``.

        Returns an empty list when no multimodal input is present; otherwise
        a tuple of tensors, each tensor corresponding to a multimodal data
        item (image or video), in the order the modalities were parsed.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        handlers = {
            "images": self._process_image_input,
            "videos": self._process_video_input,
        }

        embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the dictionary itself so the
        # order of the modalities is preserved.
        for name, parsed_input in modalities.items():
            handler = handlers.get(name)
            if handler is not None:
                embeddings += handler(parsed_input)

        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner language model and return its output.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        supplied. Extra keyword arguments are ignored.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=embeds)

    def forward_images(
        self,
        pixel_values_images: list[torch.Tensor],
        image_sizes_images: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Encode, project and post-process images into embeddings.

        Args:
            pixel_values_images: One entry per image; ``len(entry)`` is used
                as that image's number of grids.
            image_sizes_images: Per-image sizes, converted to a list for
                post-processing.

        Returns:
            A tuple built from the result of ``anyres_postprocessing``.
        """
        flat_pixels = flatten_bn(pixel_values_images, concat=True)

        # The first output token is dropped unless the encoder is SigLIP
        # (presumably the class token - confirm).
        first_patch = 0 if "siglip" in self.vision_config.model_type else 1
        encoded = self.vision_model(flat_pixels)[:, first_patch:]

        encoded = encoded.to(dtype=self.mm_projector.dtype)
        projected = self.mm_projector(encoded)  # b (h w) d

        # Regroup the flat grid batch back into per-image chunks.
        grids_per_image = list(map(len, pixel_values_images))
        per_image_outs = torch.split(projected, grids_per_image, dim=0)

        # newline for anyres postprocessing
        image_features = anyres_postprocessing(
            image_forward_outs=per_image_outs,
            image_sizes=image_sizes_images.tolist(),
            num_queries_vis_abstractor=(
                self.config.num_queries_vis_abstractor_image),
            unpad=self.config.unpad,
            patch_size=self.vision_config.patch_size,
            grid_size=self.vision_config.image_size,
            image_newline=self.image_newline,
            possible_resolutions=self.config.possible_resolutions,
        )

        return tuple(image_features)
797
798
799

    def forward_videos(
        self,
        pixel_values_videos: list[list[torch.Tensor]],
    ) -> tuple[torch.Tensor, ...]:
        """Encode video pixel inputs and return one feature tensor per video.

        The method flattens every inner tensor of ``pixel_values_videos``
        into one batch, runs the vision model, projects the result with
        ``self.mm_projector`` using a slow/fast query schedule, and finally
        regroups the projector outputs per inner tensor and per video.

        Args:
            pixel_values_videos: One list per video; each list holds the
                pixel tensors of that video. The first dimension of every
                inner tensor is used as its group size (presumably the
                number of grids/frames in the group -- confirm against the
                processor).

        Returns:
            A tuple with one tensor per entry of ``pixel_values_videos``,
            obtained by concatenating the projected features of that
            video's groups.

        Raises:
            RuntimeError: If the projector outputs accumulated so far
                exceed the size of the group currently being filled.
        """
        pixel_values_videos_flat = flatten_bn(
            [frame for frames in pixel_values_videos for frame in frames],
            concat=True,
        )

        # Keep every token for siglip-style encoders; otherwise drop the
        # first token (presumably a CLS token -- confirm for new encoders).
        visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
        video_forward_outs = self.vision_model(
            pixel_values_videos_flat)[:, visual_token_idx:]

        video_forward_outs = video_forward_outs.to(
            dtype=self.mm_projector.dtype)

        # Run MM-Projector
        # len(num_grids) == len(num_queries_vis_abstractors) + 1
        grid_idx = 0
        num_grids = [
            grid_idx
        ]  # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56]
        num_queries_vis_abstractors = [
        ]  # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
        len_total_frames = video_forward_outs.shape[0]

        if self.config.first_last_frames_slow:
            # slowfast (first_last_frames_slow): the flattened batch is
            # treated as a single sequence whose first and last entries use
            # the "slow" query count and everything in between the "fast"
            # query count.
            assert len_total_frames != 0
            if len_total_frames <= 2:
                # Too short for a fast middle part: one slow segment.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_slow)
                grid_idx += len_total_frames
                num_grids.append(grid_idx)
            else:
                # First entry: slow.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_slow)
                grid_idx += 1
                num_grids.append(grid_idx)

                # Middle entries: fast.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_fast)
                grid_idx += len_total_frames - 2
                num_grids.append(grid_idx)

                # Last entry: slow.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_slow)
                grid_idx += 1
                num_grids.append(grid_idx)
        else:
            # slowfast: within every non-empty inner tensor the first entry
            # is a slow segment and the remaining entries form one fast
            # segment.
            for pixel_values_frames in pixel_values_videos:
                for pixel_values_frame in pixel_values_frames:
                    if len(pixel_values_frame) > 0:
                        num_queries_vis_abstractors.append(
                            self.config.num_queries_vis_abstractor_video_slow)
                        grid_idx += 1
                        num_grids.append(grid_idx)
                        num_queries_vis_abstractors.append(
                            self.config.num_queries_vis_abstractor_video_fast)
                        grid_idx = grid_idx + len(pixel_values_frame) - 1
                        num_grids.append(grid_idx)

        video_forward_outs = self.mm_projector(video_forward_outs,
                                               num_queries_vis_abstractors,
                                               num_grids)

        video_features = []  # what we want to return
        target_features = []
        target_group_size = 0
        group_counter = 0
        video_groups = [
            len(frame) for frames in pixel_values_videos for frame in frames
        ]  # for concat video features after projector

        # Accumulate projector outputs until their combined leading size
        # matches the size of the current group, then emit one concatenated
        # feature tensor for that group.
        for forward_out in video_forward_outs:
            target_group_size += len(forward_out)
            target_features.append(forward_out.flatten(0, 1))

            video_group_size = video_groups[group_counter]
            if video_group_size == target_group_size:
                video_features.append(torch.cat(target_features, dim=0))
                target_features = []
                group_counter += 1
                target_group_size = 0

            elif video_group_size < target_group_size:
                raise RuntimeError(
                    f"{video_group_size=} < {target_group_size=}")

        assert len(target_features
                   ) == 0, f"target_features is not empty!! {target_features}"
        assert len(video_groups) == len(video_features)

        # Merge the per-group features back into one tensor per video.
        feats_per_video = [len(video) for video in pixel_values_videos]
        idxs_per_video = [0, *accumulate(feats_per_video)]
        return tuple(
            torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
            for i in range(len(feats_per_video)))

    def _prepare_multimodal_kwargs(self, **kwargs: object):
        output = defaultdict(list)
        for k, v in kwargs.items():
            if len(v) < 1 or len(v[0]) < 1:
                continue  # if empty batch of empty sample

            new_k, is_video = k, False
            if (not k.endswith("_images") and not k.endswith("_videos")):
                pass
            else:
                new_k, is_video = k.split("_")[:-1], k.split("_")[-1]
                new_k = "_".join(new_k)
                is_video = is_video == "videos"

            for _sample_idx, _v in enumerate(v):  # batch -> sample
                if new_k not in ["pixel_values"]:
                    if len(output[new_k]) < _sample_idx + 1:
                        output[new_k].append(list())
                    _v = _v.detach().cpu().numpy().tolist()
                    output[new_k][_sample_idx] += _v
                elif isinstance(_v, torch.Tensor):
                    if len(output[new_k]) < _sample_idx + 1:
                        output[new_k].append(list())
                        output["is_videos"].append(list())
                    _v = list(torch.unbind(_v, dim=0))
                    output[new_k][_sample_idx] += _v
                    output["is_videos"][_sample_idx] += [
                        is_video,
                    ] * len(_v)
        return dict(output)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the wrapped language model.

        Args:
            hidden_states: Hidden states forwarded unchanged to
                ``self.language_model.compute_logits``.

        Returns:
            Whatever ``self.language_model.compute_logits`` returns.
        """
        return self.language_model.compute_logits(hidden_states)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load ``(name, tensor)`` weight pairs into this module.

        Loading is delegated to an ``AutoWeightsLoader`` built around
        ``self``; its return value is passed straight through.
        """
        return AutoWeightsLoader(self).load_weights(weights)

    def _init_possible_resolutions(
        self,
        config,
        vision_config,
    ):
        if not getattr(config, "possible_resolutions", []):
            possible_resolutions = []
            if config.anyres:
                assert config.max_num_grids > 0
                for i in range(1, config.max_num_grids + 1):
                    for j in range(1, config.max_num_grids + 1):
                        if i == 1 and j == 1 and not config.use_1x1_grid:
                            continue
                        if i * j <= config.max_num_grids:
                            possible_resolutions.append([i, j])

                possible_resolutions = [[
                    ys * vision_config.image_size,
                    xs * vision_config.image_size
                ] for ys, xs in possible_resolutions]
            return possible_resolutions
        else:
            return config.possible_resolutions

    def _init_mm_projector(
        self,
        config,
        text_config,
        vision_config,
    ):
        input_hidden_size = vision_config.hidden_size
        if config.mm_projector_type == "linear":
            mm_projector = nn.Linear(input_hidden_size,
                                     text_config.hidden_size)
            mm_projector.dtype = next(mm_projector.parameters()).dtype
        elif config.mm_projector_type == "cabstractor":
            mm_projector = HCXVisionCAbstractor(
                num_queries=config.num_queries_vis_abstractor_image,
                num_input_tokens=(vision_config.image_size //
                                  vision_config.patch_size)**2,
                encoder_hidden_size=input_hidden_size,
                hidden_size=input_hidden_size,
                output_hidden_size=text_config.hidden_size,
                pos_emb=config.proj_pos_emb,
                prenorm=config.proj_prenorm,
            )
        else:
            mm_projector = HCXVisionMlp(
                config.mm_projector_type,
                input_hidden_size,
                hidden_features=input_hidden_size,
                out_features=self.text_config.hidden_size,
            )
        return mm_projector


def unpad_image(tensor: torch.Tensor,
                original_size: tuple[int, int]) -> torch.Tensor:
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding:current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding:current_width - padding]

    return unpadded_tensor


def select_best_resolution(original_size: tuple,
                           possible_resolutions: list) -> tuple:
    """Pick the candidate resolution that best fits an image.

    Each candidate ``(height, width)`` is scored by the area the image
    covers after an aspect-preserving fit (capped at the image's own area).
    The highest score wins; ties go to the candidate wasting the fewest
    pixels, and earlier candidates win remaining ties.

    Args:
        original_size: ``(height, width)`` of the image.
        possible_resolutions: Iterable of ``(height, width)`` candidates.

    Returns:
        The winning ``(height, width)`` tuple, or ``None`` when no candidate
        is selected (for example an empty candidate list).
    """
    src_h, src_w = original_size
    winner = None
    best_effective = 0
    least_waste = float("inf")

    for cand_h, cand_w in possible_resolutions:
        ratio = min(cand_w / src_w, cand_h / src_h)
        fitted_w = int(src_w * ratio)
        fitted_h = int(src_h * ratio)
        effective = min(fitted_w * fitted_h, src_w * src_h)
        waste = (cand_w * cand_h) - effective

        is_better = effective > best_effective
        is_tie_with_less_waste = (effective == best_effective
                                  and waste < least_waste)
        if is_better or is_tie_with_less_waste:
            winner = (cand_h, cand_w)
            best_effective = effective
            least_waste = waste

    return winner


def get_anyres_image_grid_shape(
    image_size: tuple[int, int],
    grid_pinpoints: Union[str, list[tuple[int, int]]],
    patch_size: int,
) -> tuple[int, int]:
    """Return the ``(columns, rows)`` grid for the best-fitting resolution.

    Args:
        image_size: ``(width, height)`` of the image.
        grid_pinpoints: Candidate ``(height, width)`` resolutions, either as
            a list or as a string literal that evaluates to one.
        patch_size: Divisor applied to the chosen resolution.

    Returns:
        ``(width // patch_size, height // patch_size)`` of the resolution
        chosen by ``select_best_resolution``.
    """
    if isinstance(grid_pinpoints, list):
        candidates = grid_pinpoints
    else:
        # String form: parse the literal without executing arbitrary code.
        candidates = ast.literal_eval(grid_pinpoints)

    img_w, img_h = image_size
    best_h, best_w = select_best_resolution((img_h, img_w), candidates)
    return best_w // patch_size, best_h // patch_size


def reshape_and_unpad_image_features(
    image_feature: torch.Tensor,
    height: int,
    width: int,
    image_size: tuple[int, int],
    possible_resolutions: list[tuple[int, int]],
    grid_size: int,
    unpad: bool,
    image_newline: torch.Tensor,
) -> torch.Tensor:
    """Arrange multi-grid image features spatially and flatten them.

    The first entry of ``image_feature`` is treated as the base feature and
    is prepended unchanged to the result. The remaining entries are laid out
    on the grid returned by ``get_anyres_image_grid_shape`` and then either
    unpadded and given a trailing ``image_newline`` column, or simply
    flattened.

    Args:
        image_feature: Stacked features; index 0 is the base feature and the
            rest are the per-grid features.
        height: Number of feature rows per grid.
        width: Number of feature columns per grid.
        image_size: Image size, forwarded to ``get_anyres_image_grid_shape``
            and ``unpad_image`` (both unpack it as ``(width, height)``).
        possible_resolutions: Candidate resolutions for grid selection.
        grid_size: Divisor used to turn the chosen resolution into a grid
            shape.
        unpad: Whether to crop padding and append the newline column.
        image_newline: Vector appended after each row when ``unpad`` is set.

    Returns:
        The base feature followed by the flattened grid features,
        concatenated along dim 0.
    """
    base_image_feature = image_feature[0]
    image_feature = image_feature[1:]

    assert height * width == base_image_feature.shape[0], (
        f"{height=} * {width=} != {base_image_feature.shape[0]=}")

    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        image_size, possible_resolutions, grid_size)
    image_feature = image_feature.view(num_patch_height, num_patch_width,
                                       height, width, -1)

    if unpad:
        # (C, grid_rows, height, grid_cols, width) -> (C, H_total, W_total)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_size)
        # Append one newline column at the end of every row.
        image_feature = torch.cat(
            (
                image_feature,
                image_newline[:, None, None].expand(
                    *image_feature.shape[:-1], 1).to(image_feature.device),
            ),
            dim=-1,
        )
        # Back to (tokens, C).
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
    else:
        # (grid_rows, height, grid_cols, width, C) -> (tokens, C)
        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
        image_feature = image_feature.flatten(0, 3)
    image_feature = torch.cat((base_image_feature, image_feature), dim=0)

    return image_feature


def anyres_postprocessing(
    image_forward_outs: list[torch.Tensor],
    image_sizes: list[list[int]],
    possible_resolutions: list[tuple[int, int]],
    patch_size: int,
    grid_size: int,
    image_newline: torch.Tensor,
    num_queries_vis_abstractor: int = -1,
    unpad: bool = False,
) -> list[torch.Tensor]:
    """Post-process per-image features (unpad, add newline).

    Args:
        image_forward_outs: One stacked feature tensor per image; the first
            dimension is the number of grids for that image.
        image_sizes: Size of each image, indexed like
            ``image_forward_outs``.
        possible_resolutions: Candidate resolutions for grid selection.
        patch_size: With ``grid_size``, gives the per-grid feature side
            length ``grid_size // patch_size``.
        grid_size: See ``patch_size``; also forwarded to the reshape helper.
        image_newline: Newline vector appended to the features.
        num_queries_vis_abstractor: When positive, must be a perfect square;
            its square root replaces the per-grid feature side length.
        unpad: Forwarded to ``reshape_and_unpad_image_features``.

    Returns:
        A list with one 2-D feature tensor per input image.
    """
    height = width = grid_size // patch_size

    if num_queries_vis_abstractor > 0:
        assert (num_queries_vis_abstractor**0.5
                ).is_integer(), "n_queries must be square number"
        height = width = int(num_queries_vis_abstractor**0.5)

    # post-processing (unpad, add newline)
    new_image_features = []
    for image_idx, image_feature in enumerate(image_forward_outs):
        if image_feature.shape[0] > 1:
            # Several grids: arrange them spatially before flattening.
            image_feature = reshape_and_unpad_image_features(
                image_feature=image_feature,
                height=height,
                width=width,
                image_size=image_sizes[image_idx],
                possible_resolutions=possible_resolutions,
                grid_size=grid_size,
                unpad=unpad,
                image_newline=image_newline,
            )
        else:
            # Single grid: use it directly and append one newline row.
            image_feature = image_feature[0]
            image_feature = torch.cat(
                (image_feature, image_newline[None].to(image_feature.device)),
                dim=0)
        new_image_features.append(image_feature)

    return new_image_features