hyperclovax_vision.py 39.7 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# copied from : https://github.com/huggingface/transformers
import ast
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
8
from itertools import accumulate
9
from typing import Annotated, Literal
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
14
from einops import rearrange
15
16
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
17
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
18
19
20
from transformers.modeling_utils import no_init_weights

from vllm.config import VllmConfig
21
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
22
23
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.cache import BaseMultiModalProcessorCache
25
26
27
28
29
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
30
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
31
from vllm.multimodal.processing import (
32
    BaseDummyInputsBuilder,
33
34
35
36
37
38
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
)
39
from vllm.sequence import IntermediateTensors
40
from vllm.utils.tensor_schema import TensorSchema, TensorShape
41
42
43
44

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
45
46
47
48
49
50
from .utils import (
    AutoWeightsLoader,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
51
52
53
54
55
56
57
from .vision import get_vision_encoder_info

# Special-token strings used by HyperCLOVAX-Vision prompts.
# End-of-turn marker. NOTE(review): not referenced in this part of the file.
EOT: str = "<|endofturn|>"
# Prompt placeholder emitted once per image (see get_dummy_text and
# get_placeholder_str).
IMAGE_TOKEN: str = "<|dummy3|>"
# Prompt placeholder emitted once per video.
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"


58
59
60
# Based on combine_frames_into_images in
# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
def get_num_combined_frames(
    num_frames: int,
    max_grid_shape: tuple[int, int] = (3, 3),
) -> int:
    """Return how many canvases are needed to tile ``num_frames`` frames.

    Each canvas holds ``rows * cols`` frames (``max_grid_shape``); a partially
    filled final canvas still counts as one canvas.

    Args:
        num_frames: Number of video frames to pack.
        max_grid_shape: ``(rows, cols)`` of frames that fit on one canvas.

    Returns:
        ``ceil(num_frames / (rows * cols))`` as an ``int``.

    Raises:
        ValueError: If the grid holds no frames (``rows * cols <= 0``).
    """
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
    if max_num_grids <= 0:
        raise ValueError(
            f"max_grid_shape must describe a non-empty grid, got {max_grid_shape}"
        )

    # Full canvases, plus one more if any frames are left over.
    num_canvases, leftover_frames = divmod(num_frames, max_num_grids)
    return num_canvases + (1 if leftover_frames else 0)


73
class HCXVisionImagePixelInputs(TensorSchema):
    """
    Pixel inputs for images.

    Dimensions:
        - n: Number of images
        - g: Number of grids
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"] = "pixel_values"
    # One tensor per image; the grid count ``g`` may differ between images.
    pixel_values_images: Annotated[
        list[torch.Tensor], TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})
    ]
    # One 2-element size row per image.
    # NOTE(review): ordering ((h, w) vs (w, h)) comes from the HF processor —
    # confirm before relying on it.
    image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]


HCXVisionImageInputs = HCXVisionImagePixelInputs


class HCXVisionVideoPixelInputs(TensorSchema):
    """
    Pixel inputs for videos.

    Dimensions:
        - n: Number of videos
        - f: Number of frames
        - g: Number of grids
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    # Nested per video, then per ``f`` entry; ``f`` and ``g`` may vary per item.
    # NOTE(review): HCXVisionMultiModalProcessor groups these per video by
    # combined-frame canvas, so ``f`` may count canvases rather than raw frames.
    pixel_values_videos: Annotated[
        list[list[torch.Tensor]],
        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}),
    ]


HCXVisionVideoInputs = HCXVisionVideoPixelInputs


class HCXVisionProcessingInfo(BaseProcessingInfo):
    """Processing metadata for HyperCLOVAX-Vision image and video inputs."""

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed per-prompt limit for either modality.
        return {"image": None, "video": None}

    @staticmethod
    def _sum_vision_query_lengths(vision_query_length: int | list[int]) -> int:
        """Return an int length as-is, or the sum of a list of lengths."""
        if isinstance(vision_query_length, int):
            return vision_query_length
        return sum(vision_query_length)

    def get_num_image_tokens(
        self,
        *,
        vision_query_length: int | list[int],
    ) -> int:
        """Return the number of placeholder tokens for one image.

        Args:
            vision_query_length: Query length(s) reported by the HF processor
                for the image; a list is summed.
        """
        return self._sum_vision_query_lengths(vision_query_length)

    def get_num_video_tokens(
        self,
        *,
        vision_query_length: int | list[int],
    ) -> int:
        """Return the number of placeholder tokens for one video.

        Args:
            vision_query_length: Query length(s) reported by the HF processor
                for the video; a list is summed.
        """
        return self._sum_vision_query_lengths(vision_query_length)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image at the vision encoder's configured size."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Not supported for this model.

        Token counts come from the processor-computed ``vision_query_lengths``
        (see ``get_num_image_tokens``) and cannot be derived from an image size.
        The previous implementation passed ``image_width``/``image_height`` to
        ``get_num_image_tokens``, which does not accept them, so it always
        failed with ``TypeError``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "HCXVision image token counts depend on the processor-computed "
            "vision_query_lengths and cannot be derived from an image size."
        )


155
class HCXVisionDummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
    """Builds dummy prompts and media for HyperCLOVAX-Vision profiling runs."""

    def get_dummy_text(
        self,
        mm_counts: Mapping[str, int],
    ) -> str:
        """Return one IMAGE_TOKEN per image followed by one VIDEO_TOKEN per video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and 32-frame dummy videos.

        Images use ``get_image_size_with_most_features()``; video frames are one
        pixel smaller in each dimension. ``seq_len`` is not used.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = 32

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                # NOTE(review): the -1 offset is deliberate upstream; its
                # reason is not visible here.
                width=target_width - 1,
                height=target_height - 1,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }


197
class HCXVisionMultiModalProcessor(BaseMultiModalProcessor[HCXVisionProcessingInfo]):
    """Multimodal processor for HyperCLOVAX-Vision.

    The HF processor is called twice: once on the prompt text alone and once
    on the media alone (batched as a single item). Placeholder expansion is
    described by ``_get_prompt_updates``.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on ``prompt`` and ``mm_data``.

        Args:
            prompt: Prompt text.
            mm_data: Processor data; may contain ``"images"`` and/or ``"videos"``.
            mm_kwargs: Keyword arguments used to build the HF processor.
            tok_kwargs: Tokenization kwargs. NOTE(review): not forwarded to the
                HF processor here.

        Returns:
            The text-only outputs updated with per-item media outputs.
        """
        videos = mm_data.get("videos")
        if videos is not None:
            # Frames are cast to uint8 before reaching the HF processor. Build a
            # new mapping instead of writing into the caller's ``mm_data``.
            mm_data = {
                **mm_data,
                "videos": [
                    video if video.dtype == np.uint8 else video.astype(np.uint8)
                    for video in videos
                ],
            }

        processed_outputs = self.info.ctx.call_hf_processor(
            hf_processor=self.info.get_hf_processor(**mm_kwargs),
            data=dict(
                text=prompt,
                images=None,
                videos=None,
            ),
        )  # text-only

        if len(mm_data) > 0:
            images = mm_data.get("images")
            videos = mm_data.get("videos")

            # batchify input as a single item
            mm_outputs = self.info.ctx.call_hf_processor(
                hf_processor=self.info.get_hf_processor(**mm_kwargs),
                data=dict(
                    text=None,
                    images=None if images is None else [images],
                    videos=None if videos is None else [videos],
                ),
            )  # mm-only

            # Unwrap the single-item batch added above.
            for key, value in list(mm_outputs.items()):
                if isinstance(value, list) and len(value) > 0:
                    assert len(value) == 1
                    mm_outputs[key] = value[0]

            if images:
                mm_outputs["image_sizes_images"] = torch.tensor(
                    mm_outputs["image_sizes_images"]
                )
                mm_outputs["vision_query_lengths_images"] = torch.tensor(
                    mm_outputs["vision_query_lengths_images"]
                )

            if videos:
                # The HF outputs are flat over combined-frame canvases; video i
                # owns entries [bounds[i], bounds[i + 1]).
                bounds = [
                    0,
                    *accumulate(
                        get_num_combined_frames(len(video)) for video in videos
                    ),
                ]
                spans = list(zip(bounds, bounds[1:]))
                pixel_values_videos = mm_outputs["pixel_values_videos"]
                query_lengths_videos = mm_outputs["vision_query_lengths_videos"]
                mm_outputs["pixel_values_videos"] = [
                    pixel_values_videos[start:end] for start, end in spans
                ]
                mm_outputs["vision_query_lengths_videos"] = [
                    torch.tensor(query_lengths_videos[start:end])
                    for start, end in spans
                ]

            processed_outputs.update(mm_outputs)

        return processed_outputs

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # The text is processed without media, so its placeholders are never
        # expanded by the HF processor.
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each placeholder token with ``num_tokens`` copies of itself."""
        hf_config = self.info.get_hf_config()
        placeholder = {
            "image": hf_config.image_token_id,
            "video": hf_config.video_token_id,
        }

        def get_replacement_hyperclovax(
            item_idx: int,
            modality: str,
            out_mm_kwargs: MultiModalKwargsItems,
        ):
            out_item = out_mm_kwargs[modality][item_idx]

            if modality == "image":
                lens = out_item["vision_query_lengths_images"].data.tolist()
                num_tokens = self.info.get_num_image_tokens(vision_query_length=lens)
            elif modality == "video":
                lens = out_item["vision_query_lengths_videos"].data.tolist()
                num_tokens = self.info.get_num_video_tokens(vision_query_length=lens)
            else:
                raise NotImplementedError(modality)

            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[
                    placeholder[modality],
                ],
                replacement=partial(
                    get_replacement_hyperclovax,
                    modality=modality,
                    out_mm_kwargs=out_mm_kwargs,
                ),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every field is batched per item of its modality.
        return dict(
            pixel_values_images=MultiModalFieldConfig.batched("image"),
            image_sizes_images=MultiModalFieldConfig.batched("image"),
            vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.batched("video"),
            vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
        )


def _build_hcxvision_hf_info(
    ctx: InputProcessingContext,
) -> HCXVisionProcessingInfo:
    """Info factory passed to ``MULTIMODAL_REGISTRY.register_processor``."""
    return HCXVisionProcessingInfo(ctx)


def _build_hcxvision_hf_processor(
    info: HCXVisionProcessingInfo,
    dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Processor factory passed to ``MULTIMODAL_REGISTRY.register_processor``.

    Raises:
        NotImplementedError: If ``info`` is not an ``HCXVisionProcessingInfo``.
    """
    if not isinstance(info, HCXVisionProcessingInfo):
        raise NotImplementedError(type(info))

    return HCXVisionMultiModalProcessor(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )


def init_vision_tower_for_hcxvision(
    vision_config,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    use_nth_layer: int | None = None,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel:
    """Build a CLIP or SigLIP vision tower truncated after ``use_nth_layer``.

    Args:
        vision_config: A ``CLIPVisionConfig`` or ``SiglipVisionConfig``.
        quant_config: Quantization config forwarded to the tower.
        multimodal_config: Multimodal config forwarded to the tower.
        use_nth_layer: Index of the last encoder layer to keep, Python-style
            (negative counts from the end). A non-int keeps every layer.
        require_post_norm: Forwarded to the tower.
        prefix: Weight-name prefix for the tower.

    Raises:
        ValueError: If ``use_nth_layer`` selects no layers.
        NotImplementedError: For any other vision config type.
    """
    num_hidden_layers = vision_config.num_hidden_layers
    if isinstance(use_nth_layer, int):
        if use_nth_layer >= 0:
            num_hidden_layers = use_nth_layer + 1
        else:
            num_hidden_layers = num_hidden_layers + use_nth_layer + 1
        if num_hidden_layers < 1:
            raise ValueError(
                f"use_nth_layer={use_nth_layer} selects no layers of a "
                f"{vision_config.num_hidden_layers}-layer vision tower"
            )

    tower_kwargs = dict(
        quant_config=quant_config,
        multimodal_config=multimodal_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(vision_config, **tower_kwargs)
    if isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(vision_config, **tower_kwargs)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


class HCXVisionMlp(nn.Module):
    """Two-layer MLP projector: ``fc1 -> act -> fc2``.

    ``mm_projector_type`` selects the hidden width: ``"mlp"`` uses
    ``hidden_features``; ``"inverted_mlp"`` uses ``2 * hidden_features``.
    """

    def __init__(
        self,
        mm_projector_type,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        """
        Args:
            mm_projector_type: ``"mlp"`` or ``"inverted_mlp"``.
            in_features: Input feature size.
            hidden_features: Base hidden size; falsy values fall back to
                ``in_features``.
            out_features: Output size; falsy values fall back to ``in_features``.
            act_layer: Activation class, instantiated with no arguments.

        Raises:
            NotImplementedError: For any other ``mm_projector_type``.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mm_projector_type = mm_projector_type

        if mm_projector_type == "mlp":
            width = hidden_features
        elif mm_projector_type == "inverted_mlp":
            width = 2 * hidden_features
        else:
            raise NotImplementedError(f"{mm_projector_type} is not implemented")

        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class HCXVisionCAbstractor(nn.Module):
    """
    This module is based on C-Abstractor, whose license is under apache-2.0.
    You can check the original code at
    https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
    and we made necessary modifications.

    Encoder tokens are reshaped into a square ``h x w`` map, passed through a
    RegStage block (``s1``), adaptively average-pooled to the query grid,
    passed through a second RegStage block (``s2``), flattened back into tokens
    and projected by an MLP (``readout``).
    """

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb: bool = True,
        prenorm: bool = False,
    ):
        """
        Args:
            num_queries: Output tokens per input on the fixed path; must be a
                perfect square.
            num_input_tokens: Encoder tokens per input (length of ``pos_emb``).
            encoder_hidden_size: Channel size of the encoder features.
            hidden_size: Channel size inside the RegStage blocks.
            output_hidden_size: Channel size of the projected output tokens.
            pos_emb: Add a learnable positional embedding to the input.
            prenorm: Apply LayerNorm to the input before everything else.
        """
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Positional embedding
        if pos_emb:
            self.pos_emb = torch.nn.Parameter(
                torch.zeros(1, num_input_tokens, encoder_hidden_size)
            )
            self.pos_emb.data.normal_(mean=0.0, std=0.02)
        else:
            self.pos_emb = None

        # (Optional) Pre-normalization layer
        self.prenorm = LayerNorm(encoder_hidden_size) if prenorm else None

        self.build_net(
            num_queries, encoder_hidden_size, hidden_size, output_hidden_size
        )
        self.dtype = next(self.parameters()).dtype

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: list[int] | None = None,
        num_grids: list[int] | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Project encoder tokens into the language-model embedding space.

        Args:
            x: Encoder tokens, assumed ``(B, L, encoder_hidden_size)`` with a
                perfect-square ``L``.
            num_queries_vis_abstractors: Optional per-group query counts (each
                a perfect square). When given, the adaptive path is used.
            num_grids: Group boundaries into the batch dimension; group ``i``
                is ``x[num_grids[i]:num_grids[i + 1]]``. Required with
                ``num_queries_vis_abstractors``.

        Returns:
            A ``(B, num_queries, output_hidden_size)`` tensor on the fixed
            path, or one tensor per group on the adaptive path.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            x = x + self.pos_emb

        return self._forward(
            x,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_grids=num_grids,
        )

    def _forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: list[int] | None = None,
        num_grids: list[int] | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        # x: [B, L, dim] -> [B, dim, h, w] with h == w == sqrt(L)
        num_tokens = x.shape[1]
        hw = int(num_tokens**0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)

        if num_queries_vis_abstractors is not None:
            assert num_grids is not None
            return self._forward_adaptive_num_query(
                x, num_queries_vis_abstractors, num_grids
            )

        x = self.net(x)
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.readout(x)
        return x

    def _forward_adaptive_num_query(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: list[int] | None = None,
        num_grids: list[int] | None = None,
    ) -> list[torch.Tensor]:
        # self.net is consisted by 3 layers (s1, sampler, s2)
        assert len(self.net) == 3

        x = self.net[0](x)  # s1
        new_x = []
        for i, num_queries in enumerate(num_queries_vis_abstractors):
            hw = int(num_queries**0.5)
            # Per-group pooling to a sqrt(num_queries) grid replaces the fixed
            # sampler (self.net[1]).
            out = nn.functional.adaptive_avg_pool2d(
                x[num_grids[i] : num_grids[i + 1]], (hw, hw)
            )
            out = self.net[2](out)  # s2

            out = rearrange(out, "b d h w -> b (h w) d")
            out = self.readout(out)

            new_x.append(out)
        return new_x

    def build_net(
        self,
        n_queries: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        depth: int = 3,
        mlp_depth: int = 2,
    ):
        """Create ``self.net`` (s1, sampler, s2) and ``self.readout``."""
        assert (n_queries**0.5).is_integer(), (
            f"n_queries must be square number. n_queries: {n_queries}"
        )
        hw = int(n_queries**0.5)

        # RegBlock = ResBlock + SE
        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
        sampler = nn.AdaptiveAvgPool2d((hw, hw))
        s2 = RegBlock(
            depth,
            hidden_size,
            hidden_size,
        )

        self.net = nn.Sequential(s1, sampler, s2)
        self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def build_mlp(
        self,
        depth: int,
        hidden_size: int,
        output_hidden_size: int,
    ):
        """Return ``Linear`` followed by ``depth - 1`` (SiLU, Linear) pairs."""
        layers = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(1, depth):
            layers.append(nn.SiLU())
            layers.append(nn.Linear(output_hidden_size, output_hidden_size))
        return nn.Sequential(*layers)


@MULTIMODAL_REGISTRY.register_processor(
    _build_hcxvision_hf_processor,
    info=_build_hcxvision_hf_info,
595
596
    dummy_inputs=HCXVisionDummyInputsBuilder,
)
597
598
599
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
600
        "gate_up_proj": ["gate_proj", "up_proj"],
601
602
    }

603
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        NOTE: mutates ``config.text_config`` and ``config.vision_config`` in
        place and sets ``config.possible_resolutions``.
        """
        super().__init__()

        # init configs
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # text_config
        text_config = config.text_config
        if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
            text_config._attn_implementation = "sdpa"
        if text_config.model_type != "hyperclovax":
            # Non-HyperCLOVAX backbones get a neutral logits scale.
            text_config.logits_scaling = 1.0

        # vision_config
        vision_config = config.vision_config
        vision_config.auto_map = {}
        vision_config.anyres = config.anyres
        vision_config.max_num_grids = config.max_num_grids
        self.dtype = vllm_config.model_config.dtype

        ## possible_resolution should be matched with preprocessor_config.json
        config.possible_resolutions = self._init_possible_resolutions(
            config, vision_config
        )

        # init models & parameters
        with no_init_weights():  # weight will be loaded in from_pretrained
            self.vision_model = init_vision_tower_for_hcxvision(
                vision_config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                use_nth_layer=getattr(config, "use_nth_layer", -1),
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
        self.mm_projector = self._init_mm_projector(config, text_config, vision_config)

        self.lm_head_vocab_size = getattr(
            text_config, "padded_vocab_size", text_config.vocab_size
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        if config.anyres:
            self.image_newline = nn.Parameter(
                torch.empty(text_config.hidden_size, dtype=self.dtype)
            )
        # NOTE(review): no ``image_newline`` is created when ``config.anyres``
        # is false, yet ``forward_images`` reads ``self.image_newline``.

        self.config = config
        self.vision_config = vision_config
        self.text_config = text_config

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for ``modality``.

        Every item of a modality shares one placeholder, so ``i`` is unused.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor "video".
        """
        if modality.startswith("image"):
            return IMAGE_TOKEN
        if modality.startswith("video"):
            return VIDEO_TOKEN

        raise ValueError("Only image or video modality is supported")

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> HCXVisionImageInputs | None:
        """Build validated image inputs from ``kwargs``.

        Returns:
            ``None`` if ``pixel_values_images`` is absent or ``None``.

        Raises:
            KeyError: If pixel values are present but ``image_sizes_images``
                is not.
        """
        pixel_values_images = kwargs.pop("pixel_values_images", None)

        if pixel_values_images is None:
            return None

        image_sizes_images = kwargs.pop("image_sizes_images")

        return HCXVisionImagePixelInputs(
            pixel_values_images=pixel_values_images,
            image_sizes_images=image_sizes_images,
        )

    def _parse_and_validate_video_input(
        self,
        **kwargs: object,
    ) -> HCXVisionVideoInputs | None:
        """Build validated video inputs from ``kwargs``.

        Returns:
            ``None`` if ``pixel_values_videos`` is absent or ``None``.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)

        if pixel_values_videos is None:
            return None

        return HCXVisionVideoPixelInputs(
            pixel_values_videos=pixel_values_videos,
        )

    def _process_image_input(
        self,
        image_input: HCXVisionImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Embed validated image inputs, one tensor per image."""
        pixel_values = image_input["pixel_values_images"]
        image_sizes = image_input["image_sizes_images"]
        return self.forward_images(
            pixel_values_images=pixel_values,
            image_sizes_images=image_sizes,
        )

    def _process_video_input(
        self,
        video_input: HCXVisionVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Embed validated video inputs via ``forward_videos``."""
        return self.forward_videos(
            pixel_values_videos=video_input["pixel_values_videos"],
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs into ``{"images": ..., "videos": ...}``.

        Only modalities whose pixel values are present and not ``None`` are
        included; a ``None`` entry would otherwise reach ``embed_multimodal``
        and fail when indexed.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key == "pixel_values_images" and "images" not in modalities:
                image_input = self._parse_and_validate_image_input(**kwargs)
                if image_input is not None:
                    modalities["images"] = image_input
            elif input_key == "pixel_values_videos" and "videos" not in modalities:
                video_input = self._parse_and_validate_video_input(**kwargs)
                if video_input is not None:
                    modalities["videos"] = video_input

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text backbone."""
        return self.language_model

    def embed_multimodal(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Return one embedding tensor per multimodal item, in kwargs order.

        Returns an empty list when no image or video inputs are present.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, modality_input in modalities.items():
            if modality == "images":
                image_embeddings = self._process_image_input(modality_input)
                multimodal_embeddings += tuple(image_embeddings)
            elif modality == "videos":
                video_embeddings = self._process_video_input(modality_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone and return its hidden states.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        given (presumably a non-first pipeline stage — confirm). Extra
        ``kwargs`` are accepted and ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def forward_images(
        self,
        pixel_values_images: list[torch.Tensor],
        image_sizes_images: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Encode and project images, returning one feature tensor per image.

        Args:
            pixel_values_images: One tensor per image; its first dimension
                counts that image's grids.
            image_sizes_images: Per-image sizes, forwarded to
                ``anyres_postprocessing``.
        """
        pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)

        # Non-SigLIP encoders drop their leading output token (presumably the
        # CLIP class token).
        visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
        image_forward_outs = self.vision_model(pixel_values_image_flat)[
            :, visual_token_idx:
        ]

        image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
        image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d

        # Regroup the flattened grids back into one chunk per image.
        split_sizes = [len(item) for item in pixel_values_images]
        image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0)

        # newline for anyres postprocessing
        image_features = anyres_postprocessing(
            image_forward_outs=image_forward_outs,
            image_sizes=image_sizes_images.tolist(),
            num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image,
            unpad=self.config.unpad,
            patch_size=self.vision_config.patch_size,
            grid_size=self.vision_config.image_size,
            image_newline=self.image_newline,
            possible_resolutions=self.config.possible_resolutions,
        )

        return tuple(image_features)

    def forward_videos(
        self,
        pixel_values_videos: list[list[torch.Tensor]],
    ) -> tuple[torch.Tensor, ...]:
        """Encode video frames and project them into the text embedding space.

        All chunks of all videos are flattened into a single batch and run
        through ``self.vision_model``. The result is passed to
        ``self.mm_projector`` together with group boundaries (``num_grids``)
        and a per-group query count (``num_queries_vis_abstractors``). These
        implement a slow/fast scheme that depends on
        ``self.config.first_last_frames_slow``:

        * If true, the whole flattened batch is split into first frame (slow),
          middle frames (fast) and last frame (slow). When there are two or
          fewer frames in total, a single slow group covers all of them.
        * Otherwise, each non-empty chunk is split into its first frame (slow)
          and its remaining frames (fast).

        The projector outputs are then stitched back together so that there
        is one tensor per input chunk. The chunks of each video are then
        concatenated.

        Args:
            pixel_values_videos: Per video, a list of chunk tensors. The
                leading dimension of each chunk is counted as its number of
                frames (presumably ``(num_frames, C, H, W)``; TODO confirm
                against the processor).

        Returns:
            One feature tensor per video.

        Raises:
            RuntimeError: If the projector outputs overshoot a chunk's frame
                count while being regrouped.
        """
        # Run every chunk of every video through the vision tower as one batch.
        pixel_values_videos_flat = flatten_bn(
            [frame for frames in pixel_values_videos for frame in frames],
            concat=True,
        )

        # Keep all positions for SigLIP; otherwise drop index 0 along dim 1
        # (presumably a CLS token — confirm for the CLIP tower).
        visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
        video_forward_outs = self.vision_model(pixel_values_videos_flat)[
            :, visual_token_idx:
        ]
        video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype)

        # Run MM-Projector
        # len(num_grids) == len(num_queries_vis_abstractors) + 1
        grid_idx = 0
        # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56]
        num_grids = [grid_idx]
        # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
        num_queries_vis_abstractors = []
        len_total_frames = video_forward_outs.shape[0]

        if self.config.first_last_frames_slow:
            # slowfast (first_last_frames_slow): first and last frames of the
            # flattened batch are "slow", everything in between is "fast".
            assert len_total_frames != 0
            if len_total_frames <= 2:
                # Too few frames to split: one slow group covers all of them.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_slow
                )
                grid_idx += len_total_frames
                num_grids.append(grid_idx)
            else:
                # First frame: slow.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_slow
                )
                grid_idx += 1
                num_grids.append(grid_idx)

                # Middle frames: fast.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_fast
                )
                grid_idx += len_total_frames - 2
                num_grids.append(grid_idx)

                # Last frame: slow.
                num_queries_vis_abstractors.append(
                    self.config.num_queries_vis_abstractor_video_slow
                )
                grid_idx += 1
                num_grids.append(grid_idx)
        else:
            # slowfast: within each non-empty chunk, the first frame is "slow"
            # and the remaining frames are "fast".
            for pixel_values_frames in pixel_values_videos:
                for pixel_values_frame in pixel_values_frames:
                    if len(pixel_values_frame) > 0:
                        num_queries_vis_abstractors.append(
                            self.config.num_queries_vis_abstractor_video_slow
                        )
                        grid_idx += 1
                        num_grids.append(grid_idx)
                        num_queries_vis_abstractors.append(
                            self.config.num_queries_vis_abstractor_video_fast
                        )
                        grid_idx += len(pixel_values_frame) - 1
                        num_grids.append(grid_idx)

        video_forward_outs = self.mm_projector(
            video_forward_outs, num_queries_vis_abstractors, num_grids
        )

        # Regroup the projector outputs so there is exactly one feature
        # tensor per input chunk; video_groups holds each chunk's frame count.
        video_features = []  # what we want to return
        target_features = []
        target_group_size = 0
        group_counter = 0
        video_groups = [
            len(frame) for frames in pixel_values_videos for frame in frames
        ]  # for concat video features after projector

        for forward_out in video_forward_outs:
            target_group_size += len(forward_out)
            target_features.append(forward_out.flatten(0, 1))

            video_group_size = video_groups[group_counter]
            if video_group_size == target_group_size:
                video_features.append(torch.cat(target_features, dim=0))
                target_features = []
                group_counter += 1
                target_group_size = 0
            elif video_group_size < target_group_size:
                raise RuntimeError(f"{video_group_size=} < {target_group_size=}")

        assert len(target_features) == 0, (
            f"target_features is not empty!! {target_features}"
        )
        assert len(video_groups) == len(video_features)

        # Concatenate the per-chunk features that belong to each video.
        feats_per_video = [len(video) for video in pixel_values_videos]
        idxs_per_video = [0, *accumulate(feats_per_video)]
        return tuple(
            torch.cat(video_features[start:end])
            for start, end in zip(idxs_per_video, idxs_per_video[1:])
        )

    def _prepare_multimodal_kwargs(self, **kwargs: object):
        output = defaultdict(list)
        for k, v in kwargs.items():
            if len(v) < 1 or len(v[0]) < 1:
                continue  # if empty batch of empty sample

            new_k, is_video = k, False
919
            if not k.endswith("_images") and not k.endswith("_videos"):
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
                pass
            else:
                new_k, is_video = k.split("_")[:-1], k.split("_")[-1]
                new_k = "_".join(new_k)
                is_video = is_video == "videos"

            for _sample_idx, _v in enumerate(v):  # batch -> sample
                if new_k not in ["pixel_values"]:
                    if len(output[new_k]) < _sample_idx + 1:
                        output[new_k].append(list())
                    _v = _v.detach().cpu().numpy().tolist()
                    output[new_k][_sample_idx] += _v
                elif isinstance(_v, torch.Tensor):
                    if len(output[new_k]) < _sample_idx + 1:
                        output[new_k].append(list())
                        output["is_videos"].append(list())
                    _v = list(torch.unbind(_v, dim=0))
                    output[new_k][_sample_idx] += _v
                    output["is_videos"][_sample_idx] += [
                        is_video,
                    ] * len(_v)
        return dict(output)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
946
    ) -> torch.Tensor | None:
947
        return self.language_model.compute_logits(hidden_states)
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights into this model.

        Loading is delegated to an ``AutoWeightsLoader`` built around
        ``self``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``.
        """
        return AutoWeightsLoader(self).load_weights(weights)

    def _init_possible_resolutions(
        self,
        config,
        vision_config,
    ):
        if not getattr(config, "possible_resolutions", []):
            possible_resolutions = []
            if config.anyres:
                assert config.max_num_grids > 0
                for i in range(1, config.max_num_grids + 1):
                    for j in range(1, config.max_num_grids + 1):
                        if i == 1 and j == 1 and not config.use_1x1_grid:
                            continue
                        if i * j <= config.max_num_grids:
                            possible_resolutions.append([i, j])

972
973
974
975
                possible_resolutions = [
                    [ys * vision_config.image_size, xs * vision_config.image_size]
                    for ys, xs in possible_resolutions
                ]
976
977
978
979
980
981
982
983
984
985
986
987
            return possible_resolutions
        else:
            return config.possible_resolutions

    def _init_mm_projector(
        self,
        config,
        text_config,
        vision_config,
    ):
        input_hidden_size = vision_config.hidden_size
        if config.mm_projector_type == "linear":
988
            mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
989
990
991
992
            mm_projector.dtype = next(mm_projector.parameters()).dtype
        elif config.mm_projector_type == "cabstractor":
            mm_projector = HCXVisionCAbstractor(
                num_queries=config.num_queries_vis_abstractor_image,
993
994
                num_input_tokens=(vision_config.image_size // vision_config.patch_size)
                ** 2,
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
                encoder_hidden_size=input_hidden_size,
                hidden_size=input_hidden_size,
                output_hidden_size=text_config.hidden_size,
                pos_emb=config.proj_pos_emb,
                prenorm=config.proj_prenorm,
            )
        else:
            mm_projector = HCXVisionMlp(
                config.mm_projector_type,
                input_hidden_size,
                hidden_features=input_hidden_size,
                out_features=self.text_config.hidden_size,
            )
        return mm_projector


def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor:
    """Crop the symmetric padding left when an image was fit into ``tensor``.

    The padded axis is decided by comparing aspect ratios. If the original
    image is relatively wider than ``tensor``, rows are cropped from the top
    and bottom. Otherwise, columns are cropped from the left and right.

    Args:
        tensor: Feature map whose ``shape[1:]`` is read as ``(height, width)``
            (i.e. indexed as ``(C, H, W)``).
        original_size: ``(width, height)`` of the original image.

    Returns:
        A slice of ``tensor`` with the padding removed. When the padding
        difference is odd, one extra row or column is kept.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Width-limited: the image filled the full width; crop rows.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        # Height-limited (or equal aspect): crop columns.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """Pick the candidate resolution that best fits an image.

    Each candidate is scored by the number of pixels the image keeps after
    being scaled, with its aspect ratio preserved, to fit inside it. This is
    capped at the original pixel count, so upscaling earns nothing. Ties are
    broken by the smallest unused candidate area.

    Args:
        original_size: ``(height, width)`` of the image.
        possible_resolutions: Candidate ``(height, width)`` pairs.

    Returns:
        The best ``(height, width)`` tuple, or ``None`` if
        ``possible_resolutions`` is empty.
    """
    original_height, original_width = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for height, width in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = (
            int(original_width * scale),
            int(original_height * scale),
        )
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (height, width)

    return best_fit


def get_anyres_image_grid_shape(
    image_size: tuple[int, int],
    grid_pinpoints: str | list[tuple[int, int]],
    patch_size: int,
) -> tuple[int, int]:
    """Return the tile grid chosen for an image as ``(num_cols, num_rows)``.

    Args:
        image_size: ``(width, height)`` of the original image.
        grid_pinpoints: Candidate ``(height, width)`` resolutions. This is
            either a sequence (list or tuple) or its string literal form,
            which is parsed with ``ast.literal_eval``.
        patch_size: Divisor applied to the chosen resolution. Callers in this
            module pass the tile size (``grid_size``).

    Returns:
        ``(width // patch_size, height // patch_size)`` of the resolution
        chosen by ``select_best_resolution``.
    """
    # Only strings need parsing. Any other sequence is used as-is rather than
    # being passed to ast.literal_eval, which would reject e.g. a tuple.
    possible_resolutions = (
        ast.literal_eval(grid_pinpoints)
        if isinstance(grid_pinpoints, str)
        else grid_pinpoints
    )

    original_width, original_height = image_size
    height, width = select_best_resolution(
        (original_height, original_width), possible_resolutions
    )
    return width // patch_size, height // patch_size


def reshape_and_unpad_image_features(
    image_feature: torch.Tensor,
    height: int,
    width: int,
    image_size: tuple[int, int],
    possible_resolutions: list[tuple[int, int]],
    grid_size: int,
    unpad: bool,
    image_newline: torch.Tensor,
) -> torch.Tensor:
    """Arrange an anyres image's per-tile features into one token sequence.

    ``image_feature[0]`` is the base tile and ``image_feature[1:]`` are the
    grid tiles. The grid shape comes from ``get_anyres_image_grid_shape(
    image_size, possible_resolutions, grid_size)``, and the grid tiles are
    viewed as ``(num_patch_height, num_patch_width, height, width, -1)``.

    With ``unpad``, the tiles are stitched into a single ``(C, H, W)`` map
    and cropped by ``unpad_image``. ``image_newline`` is then appended at the
    end of every row, and the map is flattened row by row into tokens.
    Without ``unpad``, the stitched map is flattened row-major with no
    newline tokens.

    Returns:
        The base tile's tokens followed by the grid tokens, along dim 0.

    Raises:
        AssertionError: If ``height * width`` differs from the base tile's
            token count.
    """
    base_image_feature = image_feature[0]
    image_feature = image_feature[1:]

    assert height * width == base_image_feature.shape[0], (
        f"{height=} * {width=} != {base_image_feature.shape[0]=}"
    )

    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        image_size, possible_resolutions, grid_size
    )
    image_feature = image_feature.view(
        num_patch_height, num_patch_width, height, width, -1
    )

    if unpad:
        # (nph, npw, h, w, C) -> (C, nph, h, npw, w) -> (C, nph*h, npw*w)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_size)
        # Append one newline column to every row: (C, H', W') -> (C, H', W'+1)
        image_feature = torch.cat(
            (
                image_feature,
                image_newline[:, None, None]
                .expand(*image_feature.shape[:-1], 1)
                .to(image_feature.device),
            ),
            dim=-1,
        )
        # (C, H', W'+1) -> (H' * (W'+1), C)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
    else:
        # (nph, npw, h, w, C) -> (nph, h, npw, w, C) -> (tokens, C)
        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
        image_feature = image_feature.flatten(0, 3)
    image_feature = torch.cat((base_image_feature, image_feature), dim=0)

    return image_feature


def anyres_postprocessing(
    image_forward_outs: list[torch.Tensor],
    image_sizes: list[list[int]],
    possible_resolutions: list[tuple[int, int]],
    patch_size: int,
    grid_size: int,
    image_newline: torch.Tensor,
    num_queries_vis_abstractor: int = -1,
    unpad: bool = False,
) -> list[torch.Tensor]:
    """Turn per-image tile features into flat token sequences.

    Each tile is treated as a ``height x width`` grid of tokens. By default
    the grid side is ``grid_size // patch_size``. When
    ``num_queries_vis_abstractor`` is positive, the side is instead its
    (integer) square root.

    An image with more than one tile is handled by
    ``reshape_and_unpad_image_features`` (base tile plus grid, optionally
    unpadded). A single-tile image keeps its tokens and gets one
    ``image_newline`` token appended.

    Args:
        image_forward_outs: Per-image tile features; ``shape[0]`` is the
            tile count.
        image_sizes: Per-image original size, forwarded as ``image_size``.
        possible_resolutions: Candidate resolutions for the anyres grid.
        patch_size: Encoder patch size.
        grid_size: Tile size.
        image_newline: Embedding used as the newline token.
        num_queries_vis_abstractor: Tokens per tile after the abstractor, if
            positive; must be a perfect square.
        unpad: Whether to remove letterbox padding from multi-tile images.

    Returns:
        One feature tensor per image.

    Raises:
        AssertionError: If ``num_queries_vis_abstractor`` is positive but not
            a perfect square.
    """
    height = width = grid_size // patch_size

    if num_queries_vis_abstractor > 0:
        assert (num_queries_vis_abstractor**0.5).is_integer(), (
            "n_queries must be square number"
        )
        # Tiles are then interpreted as a side x side grid of query tokens.
        height = width = int(num_queries_vis_abstractor**0.5)

    # post-processing (unpad, add newline)
    new_image_features = []
    for image_idx, image_feature in enumerate(image_forward_outs):
        if image_feature.shape[0] > 1:
            image_feature = reshape_and_unpad_image_features(
                image_feature=image_feature,
                height=height,
                width=width,
                image_size=image_sizes[image_idx],
                possible_resolutions=possible_resolutions,
                grid_size=grid_size,  # Pass grid info if needed by helper
                unpad=unpad,
                image_newline=image_newline,
            )
        else:
            # Single tile: keep its tokens and append one newline token.
            image_feature = image_feature[0]
            image_feature = torch.cat(
                (image_feature, image_newline[None].to(image_feature.device)), dim=0
            )
        new_image_features.append(image_feature)

    return new_image_features