# hyperclovax_vision.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# copied from : https://github.com/huggingface/transformers
import ast
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
8
from itertools import accumulate
9
from typing import Annotated, Literal
10
11
12
13

import numpy as np
import torch
import torch.nn as nn
14
from einops import rearrange
15
16
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
17
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
18
19

from vllm.config import VllmConfig
20
from vllm.config.multimodal import BaseDummyOptions
21
22
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.multimodal.cache import BaseMultiModalProcessorCache
24
25
26
27
28
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
29
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
30
from vllm.multimodal.processing import (
31
    BaseDummyInputsBuilder,
32
33
34
35
36
37
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
)
38
from vllm.sequence import IntermediateTensors
39
from vllm.utils.tensor_schema import TensorSchema, TensorShape
40
41
42
43

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
44
45
46
47
48
49
from .utils import (
    AutoWeightsLoader,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
50
51
52
53
54
55
from .vision import get_vision_encoder_info

# Placeholder string emitted once per image item by the dummy-input builder
# (get_dummy_text) and by HCXVisionForCausalLM.get_placeholder_str.
IMAGE_TOKEN: str = "<|dummy3|>"
# Placeholder string emitted once per video item by the same two call sites.
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"


# Based on combine_frames_into_images in
# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
def get_num_combined_frames(
    num_frames: int,
    max_grid_shape: tuple[int, int] = (3, 3),
) -> int:
    """Return how many canvases are needed to tile ``num_frames`` frames.

    Each canvas holds up to ``rows * cols`` frames, where ``(rows, cols)`` is
    ``max_grid_shape``. A final, partially filled canvas counts as one more.

    Args:
        num_frames: Number of video frames to combine.
        max_grid_shape: ``(rows, cols)`` grid of frames per canvas.

    Returns:
        ``ceil(num_frames / (rows * cols))`` as an ``int``.

    Raises:
        ValueError: If the grid holds no frames (``rows * cols <= 0``).
    """
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
    if max_num_grids <= 0:
        raise ValueError(f"max_grid_shape must hold frames, got {max_grid_shape}")

    num_canvases, leftover_frames = divmod(num_frames, max_num_grids)
    # A partially filled trailing canvas still needs its own image.
    return num_canvases + (1 if leftover_frames else 0)


class HCXVisionImagePixelInputs(TensorSchema):
    """
    Pixel inputs for a batch of images.

    Dimensions:
        - n: Number of images
        - g: Number of grids
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"] = "pixel_values"
    # One (g, 3, h, w) tensor per image; g may differ between images.
    pixel_values_images: Annotated[
        list[torch.Tensor], TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})
    ]
    # One 2-element size entry per image, shape (n, 2).
    image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]


HCXVisionImageInputs = HCXVisionImagePixelInputs


class HCXVisionVideoPixelInputs(TensorSchema):
    """
    Pixel inputs for a batch of videos.

    Dimensions:
        - n: Number of videos
        - f: Number of combined-frame canvases per video (frames are tiled
          into canvases; see get_num_combined_frames)
        - g: Number of grids
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    # Nested per-video lists; both f and g may differ between videos.
    pixel_values_videos: Annotated[
        list[list[torch.Tensor]],
        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}),
    ]


HCXVisionVideoInputs = HCXVisionVideoPixelInputs


class HCXVisionProcessingInfo(BaseProcessingInfo):
    """Processing metadata for HyperCLOVAX-SEED vision models.

    Placeholder token counts are taken from the ``vision_query_lengths_*``
    values produced by the HF processor, not computed from the image size.
    """

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # ``None``: no fixed per-prompt limit is declared for either modality.
        return {"image": None, "video": None}

    @staticmethod
    def _sum_query_lengths(vision_query_length: int | list[int]) -> int:
        """Total the per-grid query lengths; a bare int is already a total."""
        if isinstance(vision_query_length, int):
            return vision_query_length
        return sum(vision_query_length)

    def get_num_image_tokens(
        self,
        *,
        vision_query_length: int | list[int],
    ) -> int:
        """Return the number of placeholder tokens for one image item."""
        return self._sum_query_lengths(vision_query_length)

    def get_num_video_tokens(
        self,
        *,
        vision_query_length: int | list[int],
    ) -> int:
        """Return the number of placeholder tokens for one video item."""
        return self._sum_query_lengths(vision_query_length)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square size whose side is the vision encoder's image size."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Not supported for this model.

        The previous implementation called ``get_num_image_tokens`` with
        ``image_width``/``image_height``, which that method does not accept,
        so it always failed with a ``TypeError``. The token count depends on
        the HF processor's per-image query lengths, which cannot be derived
        from the image size alone.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "HCXVision image token counts come from the HF processor's "
            "vision_query_lengths_images and cannot be derived from the "
            "image size alone."
        )
class HCXVisionDummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
    """Builds placeholder prompts and synthetic image/video data."""

    def get_dummy_text(
        self,
        mm_counts: Mapping[str, int],
    ) -> str:
        """Return one placeholder token per requested image, then per video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return synthetic images and 32-frame videos.

        Images use ``get_image_size_with_most_features``. Videos use that size
        minus one pixel per side. ``mm_options`` may carry per-modality
        overrides under ``"image"`` / ``"video"``.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = 32

        image_overrides = mm_options.get("image")
        video_overrides = mm_options.get("video")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                # NOTE(review): videos are one pixel smaller than images; the
                # reason is not visible here — confirm it is intentional.
                width=target_width - 1,
                height=target_height - 1,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }
class HCXVisionMultiModalProcessor(BaseMultiModalProcessor[HCXVisionProcessingInfo]):
    """Multimodal processor for HyperCLOVAX-SEED vision models.

    The prompt text and the multimodal data are passed through the HF
    processor in two separate calls. Placeholder tokens are then expanded by
    ``_get_prompt_updates`` using the per-item vision query lengths.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        images = mm_data.get("images")
        videos = mm_data.get("videos")
        if videos is not None:
            # Feed uint8 frames to the HF processor. Build a new list instead
            # of writing back into ``mm_data`` so the caller's input is not
            # mutated.
            videos = [
                video if video.dtype == np.uint8 else video.astype(np.uint8)
                for video in videos
            ]

        processed_outputs = self.info.ctx.call_hf_processor(
            hf_processor=self.info.get_hf_processor(**mm_kwargs),
            data=dict(
                text=prompt,
                images=None,
                videos=None,
            ),
        )  # text-only

        if len(mm_data) > 0:
            # batchify input as a single item
            _processed_outputs = self.info.ctx.call_hf_processor(
                hf_processor=self.info.get_hf_processor(**mm_kwargs),
                data=dict(
                    text=None,
                    images=None if images is None else [images],
                    videos=None if videos is None else [videos],
                ),
            )  # mm-only

            # Undo the single-item batching above. Iterate over a snapshot
            # because values are reassigned inside the loop.
            for k, v in list(_processed_outputs.items()):
                if isinstance(v, list) and len(v) > 0:
                    assert len(v) == 1
                    _processed_outputs[k] = v[0]

            if images:
                _processed_outputs["image_sizes_images"] = torch.tensor(
                    _processed_outputs["image_sizes_images"]
                )
                _processed_outputs["vision_query_lengths_images"] = torch.tensor(
                    _processed_outputs["vision_query_lengths_images"]
                )

            if videos:
                # Video i owns canvases [_idx_per_video[i], _idx_per_video[i + 1]).
                _idx_per_video = [
                    0,
                    *accumulate(
                        get_num_combined_frames(len(video)) for video in videos
                    ),
                ]
                _video_bounds = list(zip(_idx_per_video[:-1], _idx_per_video[1:]))
                _processed_outputs["pixel_values_videos"] = [
                    _processed_outputs["pixel_values_videos"][start:end]
                    for start, end in _video_bounds
                ]
                _processed_outputs["vision_query_lengths_videos"] = [
                    torch.tensor(
                        _processed_outputs["vision_query_lengths_videos"][start:end]
                    )
                    for start, end in _video_bounds
                ]

            processed_outputs.update(_processed_outputs)

        return processed_outputs

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # The text is processed without any multimodal data in
        # _call_hf_processor, so the HF processor never expands placeholders.
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        placeholder = {
            "image": hf_config.image_token_id,
            "video": hf_config.video_token_id,
        }

        def get_replacement_hyperclovax(
            item_idx: int,
            modality: str,
            out_mm_kwargs: MultiModalKwargsItems,
        ):
            # Repeat the placeholder id once per vision query of this item.
            out_item = out_mm_kwargs[modality][item_idx]

            if modality == "image":
                lens = out_item["vision_query_lengths_images"].data.tolist()
                num_tokens = self.info.get_num_image_tokens(vision_query_length=lens)
            elif modality == "video":
                lens = out_item["vision_query_lengths_videos"].data.tolist()
                num_tokens = self.info.get_num_video_tokens(vision_query_length=lens)
            else:
                raise NotImplementedError(modality)

            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(
                    get_replacement_hyperclovax,
                    modality=modality,
                    out_mm_kwargs=out_mm_kwargs,
                ),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every field is batched per item of its modality.
        return dict(
            pixel_values_images=MultiModalFieldConfig.batched("image"),
            image_sizes_images=MultiModalFieldConfig.batched("image"),
            vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.batched("video"),
            vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
        )

def _build_hcxvision_hf_info(
    ctx: InputProcessingContext,
) -> HCXVisionProcessingInfo:
    """Create the processing-info object used by the registered processor."""
    return HCXVisionProcessingInfo(ctx)


def _build_hcxvision_hf_processor(
    info: HCXVisionProcessingInfo,
    dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor registered for HCXVision models.

    Raises:
        NotImplementedError: If ``info`` is not an ``HCXVisionProcessingInfo``.
    """
    if not isinstance(info, HCXVisionProcessingInfo):
        raise NotImplementedError(type(info))

    return HCXVisionMultiModalProcessor(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )


def init_vision_tower_for_hcxvision(
    vision_config,
    quant_config: QuantizationConfig | None,
    *,
    use_nth_layer: int | None = None,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel:
    """Build the CLIP or SigLIP vision tower, truncated to ``use_nth_layer``.

    Args:
        vision_config: A ``CLIPVisionConfig`` or ``SiglipVisionConfig``.
        quant_config: Optional quantization config forwarded to the tower.
        use_nth_layer: Index of the last encoder layer to keep. A value
            ``n >= 0`` keeps ``n + 1`` layers. A negative value counts from
            the end (``-1`` keeps all layers). Any non-int (e.g. ``None``)
            keeps all layers.
        require_post_norm: Forwarded to the vision model constructor.
        prefix: Weight-name prefix for the tower.

    Returns:
        The vision model built with ``num_hidden_layers_override``.

    Raises:
        ValueError: If ``use_nth_layer`` resolves to fewer than one layer.
        NotImplementedError: If ``vision_config`` is of an unsupported type.
    """
    num_hidden_layers = vision_config.num_hidden_layers
    if isinstance(use_nth_layer, int):
        if use_nth_layer >= 0:
            num_hidden_layers = use_nth_layer + 1
        else:
            num_hidden_layers += use_nth_layer + 1
        if num_hidden_layers < 1:
            # Would otherwise silently build a tower with no encoder layers.
            raise ValueError(
                f"use_nth_layer={use_nth_layer} leaves no encoder layers "
                f"(config has {vision_config.num_hidden_layers})"
            )

    common_kwargs = dict(
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )
    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(vision_config, **common_kwargs)
    if isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(vision_config, **common_kwargs)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


class HCXVisionMlp(nn.Module):
    """Two-layer MLP projector: ``fc1 -> act -> fc2``.

    ``mm_projector_type`` selects the inner width: ``"mlp"`` uses
    ``hidden_features`` and ``"inverted_mlp"`` uses ``2 * hidden_features``.
    ``hidden_features`` and ``out_features`` default to ``in_features``.

    Raises:
        NotImplementedError: For any other ``mm_projector_type``.
    """

    def __init__(
        self,
        mm_projector_type,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mm_projector_type = mm_projector_type

        if mm_projector_type == "mlp":
            inner_features = hidden_features
        elif mm_projector_type == "inverted_mlp":
            inner_features = 2 * hidden_features
        else:
            raise NotImplementedError(f"{mm_projector_type} is not implemented")

        self.fc1 = nn.Linear(in_features, inner_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(inner_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class HCXVisionCAbstractor(nn.Module):
    """
    This module is based on C-Abstractor, whose license is under apache-2.0.
    You can check the original code at
    https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
    and we made necessary modifications.

    Data flow: optional pre-norm -> optional learned positional embedding ->
    ``net`` = (RegStage ``s1``, adaptive average pool to ``num_queries``
    tokens, RegStage ``s2``) -> MLP ``readout`` to ``output_hidden_size``.
    When ``num_queries_vis_abstractors``/``num_grids`` are given, the pooling
    size is chosen per group of inputs instead (see
    ``_forward_adaptive_num_query``).
    """

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb: bool = True,
        prenorm: bool = False,
    ):
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Positional embedding
        if pos_emb:
            self.pos_emb = torch.nn.Parameter(
                torch.zeros(1, num_input_tokens, encoder_hidden_size)
            )
            self.pos_emb.data.normal_(mean=0.0, std=0.02)
        else:
            self.pos_emb = None

        # (Optional) Pre-normalization layer
        if prenorm:
            self.prenorm = LayerNorm(encoder_hidden_size)
        else:
            self.prenorm = None

        self.build_net(
            num_queries, encoder_hidden_size, hidden_size, output_hidden_size
        )
        # Parameter dtype, exposed so callers can cast inputs to match.
        self.dtype = next(self.parameters()).dtype

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: list[int] | None = None,
        num_grids: list[int] | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Project vision-encoder tokens to ``output_hidden_size``.

        Args:
            x: Tokens of shape ``(B, L, encoder_hidden_size)``. ``L`` must be
                a perfect square, because tokens are reshaped to an
                ``hw x hw`` map.
            num_queries_vis_abstractors: Optional per-group query counts.
            num_grids: Group boundaries along ``B``, one element longer than
                ``num_queries_vis_abstractors``. Required when that is given.

        Returns:
            ``(B, num_queries, output_hidden_size)`` on the default path, or a
            list with one such tensor per group on the adaptive path.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            x = x + self.pos_emb

        return self._forward(
            x,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_grids=num_grids,
        )

    def _forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: list[int] | None = None,
        num_grids: list[int] | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        # x: [B, L, dim] -> square feature map [B, dim, hw, hw]
        num_tokens = x.shape[1]
        hw = int(num_tokens**0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)

        if num_queries_vis_abstractors is not None:
            assert num_grids is not None
            return self._forward_adaptive_num_query(
                x, num_queries_vis_abstractors, num_grids
            )

        x = self.net(x)
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.readout(x)
        return x

    def _forward_adaptive_num_query(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: list[int] | None = None,
        num_grids: list[int] | None = None,
    ) -> list[torch.Tensor]:
        """Run ``s1`` once, then pool and project each grid group separately.

        Group ``i`` is ``x[num_grids[i]:num_grids[i + 1]]``. It is pooled to
        ``hw x hw`` with ``hw = int(sqrt(num_queries_vis_abstractors[i]))``,
        then passed through ``s2`` and the readout MLP.

        Returns:
            One ``(group_size, hw * hw, output_hidden_size)`` tensor per group.
        """
        # self.net consists of 3 layers (s1, sampler, s2); the fixed sampler
        # is replaced below by a per-group adaptive pool.
        assert len(self.net) == 3

        x = self.net[0](x)  # s1
        new_x = []
        for i, num_queries in enumerate(num_queries_vis_abstractors):
            hw = int(num_queries**0.5)
            sampler = nn.AdaptiveAvgPool2d((hw, hw))
            out = sampler(x[num_grids[i] : num_grids[i + 1], :])
            out = self.net[2](out)  # s2

            out = rearrange(out, "b d h w -> b (h w) d")
            out = self.readout(out)

            new_x.append(out)
        return new_x

    def build_net(
        self,
        n_queries: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        depth: int = 3,
        mlp_depth: int = 2,
    ) -> None:
        """Create ``self.net`` (s1, sampler, s2) and ``self.readout``."""
        assert (n_queries**0.5).is_integer(), (
            f"n_queries must be square number. n_queries: {n_queries}"
        )
        hw = int(n_queries**0.5)

        # RegBlock = ResBlock + SE
        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
        sampler = nn.AdaptiveAvgPool2d((hw, hw))
        s2 = RegBlock(
            depth,
            hidden_size,
            hidden_size,
        )

        self.net = nn.Sequential(s1, sampler, s2)
        self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def build_mlp(
        self,
        depth: int,
        hidden_size: int,
        output_hidden_size: int,
    ) -> nn.Sequential:
        """Return ``depth`` Linear layers separated by SiLU activations."""
        layers = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(1, depth):
            layers.append(nn.SiLU())
            layers.append(nn.Linear(output_hidden_size, output_hidden_size))
        return nn.Sequential(*layers)


@MULTIMODAL_REGISTRY.register_processor(
    _build_hcxvision_hf_processor,
    info=_build_hcxvision_hf_info,
592
593
    dummy_inputs=HCXVisionDummyInputsBuilder,
)
594
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
595
596
597
598
599
600
601
602
603
    """
    HyperCLOVAX-SEED Vision-Language Model (V1 architecture).

    Supports:
    - HyperCLOVAX-SEED-Vision-Instruct-3B

    Uses CLIP/SigLIP as the vision encoder with C-Abstractor projector.
    """

604
605
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
606
        "gate_up_proj": ["gate_proj", "up_proj"],
607
608
    }

609
610
611
612
613
614
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, the projector and the language model.

        Note: this mutates the HF config objects held by ``vllm_config``
        (text attention implementation, logits scaling, vision anyres
        settings and ``possible_resolutions``).
        """
        super().__init__()

        # init configs
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        # text_config
        text_config = config.text_config
        if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
            text_config._attn_implementation = "sdpa"
        if text_config.model_type != "hyperclovax":
            # Only hyperclovax text models keep their configured scaling.
            text_config.logits_scaling = 1.0
        # vision_config
        vision_config = config.vision_config
        vision_config.auto_map = {}
        vision_config.anyres = config.anyres
        vision_config.max_num_grids = config.max_num_grids
        self.dtype = vllm_config.model_config.dtype

        ## possible_resolution should be matched with preprocessor_config.json
        config.possible_resolutions = self._init_possible_resolutions(
            config, vision_config
        )

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_model = init_vision_tower_for_hcxvision(
                vision_config,
                quant_config=quant_config,
                use_nth_layer=getattr(config, "use_nth_layer", -1),
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mm_projector = self._init_mm_projector(
                config, text_config, vision_config
            )

            # NOTE(review): image_newline only exists when anyres is enabled,
            # but forward_images reads it unconditionally.
            if config.anyres:
                self.image_newline = nn.Parameter(
                    torch.empty(text_config.hidden_size, dtype=self.dtype)
                )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.config = config
        self.vision_config = vision_config
        self.text_config = text_config

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for ``modality``.

        ``i`` (the item index) is unused: every item of a modality shares the
        same placeholder. Matching is by prefix ("image…" / "video…").

        Raises:
            ValueError: If ``modality`` is neither an image nor a video one.
        """
        if modality.startswith("image"):
            return IMAGE_TOKEN
        if modality.startswith("video"):
            return VIDEO_TOKEN

        raise ValueError("Only image or video modality is supported")
    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> HCXVisionImageInputs | None:
        """Build validated image inputs from ``kwargs``.

        Returns ``None`` when ``pixel_values_images`` is absent or ``None``.
        Otherwise ``image_sizes_images`` is required (``KeyError`` if
        missing).
        """
        pixel_values_images = kwargs.pop("pixel_values_images", None)

        if pixel_values_images is None:
            return None

        image_sizes_images = kwargs.pop("image_sizes_images")

        return HCXVisionImagePixelInputs(
            pixel_values_images=pixel_values_images,
            image_sizes_images=image_sizes_images,
        )

    def _parse_and_validate_video_input(
        self,
        **kwargs: object,
    ) -> HCXVisionVideoInputs | None:
        """Build validated video inputs from ``kwargs``.

        Returns ``None`` when ``pixel_values_videos`` is absent or ``None``.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)

        if pixel_values_videos is None:
            return None

        return HCXVisionVideoPixelInputs(
            pixel_values_videos=pixel_values_videos,
        )

    def _process_image_input(
        self,
        image_input: HCXVisionImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Encode validated image inputs into one embedding per image."""
        pixel_values = image_input["pixel_values_images"]
        image_sizes = image_input["image_sizes_images"]
        return self.forward_images(
            pixel_values_images=pixel_values,
            image_sizes_images=image_sizes,
        )

    def _process_video_input(
        self,
        video_input: HCXVisionVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Encode validated video inputs via ``forward_videos``."""
        return self.forward_videos(
            pixel_values_videos=video_input["pixel_values_videos"],
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs found in ``kwargs``.

        Returns a dict with optional ``"images"`` and ``"videos"`` entries.
        Its insertion order follows where ``pixel_values_images`` and
        ``pixel_values_videos`` first appear in ``kwargs``. A value may be
        ``None`` if the corresponding pixel values were ``None``.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key == "pixel_values_images" and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if input_key == "pixel_values_videos" and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities
    def embed_multimodal(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Encode all image/video items in ``kwargs`` into embeddings.

        Returns:
            A tuple with one tensor per multimodal item. Items are grouped by
            modality, in the order the modalities first appear in ``kwargs``.
            Returns ``[]`` when no multimodal input is present.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, mm_input in modalities.items():
            # A pixel-value key that is present but None parses to None;
            # skip it instead of passing None into the encoders.
            if mm_input is None:
                continue
            if modality == "images":
                image_embeddings = self._process_image_input(mm_input)
                multimodal_embeddings += tuple(image_embeddings)
            elif modality == "videos":
                video_embeddings = self._process_video_input(mm_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone and return its hidden states.

        This method does not encode images or videos itself. When
        ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored.
        Extra ``kwargs`` are accepted and unused.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def forward_images(
        self,
        pixel_values_images: list[torch.Tensor],
        image_sizes_images: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Encode a batch of images into per-image feature tensors.

        All grids of all images go through the vision tower and projector in
        one pass. The results are then split back per image and
        post-processed by ``anyres_postprocessing``.

        Args:
            pixel_values_images: One tensor of grids per image.
            image_sizes_images: Per-image sizes, forwarded as a list.

        Returns:
            One feature tensor per image.
        """
        pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)

        # Presumably skips CLIP's leading class token (SigLIP output has none)
        # — confirm against the vision model implementations.
        visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
        image_forward_outs = self.vision_model(pixel_values_image_flat)[
            :, visual_token_idx:
        ]

        image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
        image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d

        # Regroup the flat grid batch into per-image chunks.
        split_sizes = [len(item) for item in pixel_values_images]
        image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0)

        # newline for anyres postprocessing
        image_features = anyres_postprocessing(
            image_forward_outs=image_forward_outs,
            image_sizes=image_sizes_images.tolist(),
            num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image,
            unpad=self.config.unpad,
            patch_size=self.vision_config.patch_size,
            grid_size=self.vision_config.image_size,
            # NOTE(review): image_newline is only created when config.anyres
            # is set; this raises AttributeError otherwise.
            image_newline=self.image_newline,
            possible_resolutions=self.config.possible_resolutions,
        )

        return tuple(image_features)

    def forward_videos(
        self,
814
815
816
817
818
819
        pixel_values_videos: list[list[torch.Tensor]],
    ) -> tuple[torch.Tensor, ...]:
        pixel_values_videos_flat = flatten_bn(
            [frame for frames in pixel_values_videos for frame in frames],
            concat=True,
        )
820
821

        visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
822
823
824
        video_forward_outs = self.vision_model(pixel_values_videos_flat)[
            :, visual_token_idx:
        ]
825

826
        video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype)
827
828
829
830

        # Run MM-Projector
        # len(num_grids) == len(num_queries_vis_abstractors) + 1
        grid_idx = 0
831
832
833
834
        # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56]
        num_grids = [grid_idx]
        # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
        num_queries_vis_abstractors = []
835
836
837
838
839
840
841
        len_total_frames = video_forward_outs.shape[0]

        if self.config.first_last_frames_slow:
            # slowfast (first_last_frames_slow)
            assert len_total_frames != 0
            if len_total_frames <= 2:
                num_queries_vis_abstractors.append(
842
843
                    self.config.num_queries_vis_abstractor_video_slow
                )
844
845
846
847
                grid_idx += len_total_frames
                num_grids.append(grid_idx)
            else:
                num_queries_vis_abstractors.append(
848
849
                    self.config.num_queries_vis_abstractor_video_slow
                )
850
851
852
853
                grid_idx += 1
                num_grids.append(grid_idx)

                num_queries_vis_abstractors.append(
854
855
                    self.config.num_queries_vis_abstractor_video_fast
                )
856
857
858
859
                grid_idx += len_total_frames - 2
                num_grids.append(grid_idx)

                num_queries_vis_abstractors.append(
860
861
                    self.config.num_queries_vis_abstractor_video_slow
                )
862
863
864
865
866
867
868
869
                grid_idx += 1
                num_grids.append(grid_idx)
        else:
            # slowfast
            for pixel_values_frames in pixel_values_videos:
                for pixel_values_frame in pixel_values_frames:
                    if len(pixel_values_frame) > 0:
                        num_queries_vis_abstractors.append(
870
871
                            self.config.num_queries_vis_abstractor_video_slow
                        )
872
873
874
                        grid_idx += 1
                        num_grids.append(grid_idx)
                        num_queries_vis_abstractors.append(
875
876
                            self.config.num_queries_vis_abstractor_video_fast
                        )
877
878
879
                        grid_idx = grid_idx + len(pixel_values_frame) - 1
                        num_grids.append(grid_idx)

880
881
882
        video_forward_outs = self.mm_projector(
            video_forward_outs, num_queries_vis_abstractors, num_grids
        )
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903

        video_features = []  # what we want to return
        target_features = []
        target_group_size = 0
        group_counter = 0
        video_groups = [
            len(frame) for frames in pixel_values_videos for frame in frames
        ]  # for concat video features after projector

        for forward_out in video_forward_outs:
            target_group_size += len(forward_out)
            target_features.append(forward_out.flatten(0, 1))

            video_group_size = video_groups[group_counter]
            if video_group_size == target_group_size:
                video_features.append(torch.cat(target_features, dim=0))
                target_features = []
                group_counter += 1
                target_group_size = 0

            elif video_group_size < target_group_size:
904
                raise RuntimeError(f"{video_group_size=} < {target_group_size=}")
905

906
907
908
        assert len(target_features) == 0, (
            f"target_features is not empty!! {target_features}"
        )
909
910
        assert len(video_groups) == len(video_features)

911
912
913
        feats_per_video = [len(video) for video in pixel_values_videos]
        idxs_per_video = [0, *accumulate(feats_per_video)]
        return tuple(
914
915
916
            torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]])
            for i in range(len(feats_per_video))
        )
917
918
919
920
921
922
923
924

    def _prepare_multimodal_kwargs(self, **kwargs: object):
        output = defaultdict(list)
        for k, v in kwargs.items():
            if len(v) < 1 or len(v[0]) < 1:
                continue  # if empty batch of empty sample

            new_k, is_video = k, False
925
            if not k.endswith("_images") and not k.endswith("_videos"):
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
                pass
            else:
                new_k, is_video = k.split("_")[:-1], k.split("_")[-1]
                new_k = "_".join(new_k)
                is_video = is_video == "videos"

            for _sample_idx, _v in enumerate(v):  # batch -> sample
                if new_k not in ["pixel_values"]:
                    if len(output[new_k]) < _sample_idx + 1:
                        output[new_k].append(list())
                    _v = _v.detach().cpu().numpy().tolist()
                    output[new_k][_sample_idx] += _v
                elif isinstance(_v, torch.Tensor):
                    if len(output[new_k]) < _sample_idx + 1:
                        output[new_k].append(list())
                        output["is_videos"].append(list())
                    _v = list(torch.unbind(_v, dim=0))
                    output[new_k][_sample_idx] += _v
                    output["is_videos"][_sample_idx] += [
                        is_video,
                    ] * len(_v)
        return dict(output)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
952
    ) -> torch.Tensor | None:
953
        return self.language_model.compute_logits(hidden_states)
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Loading is delegated entirely to ``AutoWeightsLoader``. The returned
        set is whatever its ``load_weights`` reports (presumably the names
        that were loaded; confirm in ``AutoWeightsLoader``).
        """
        return AutoWeightsLoader(self).load_weights(weights)

    def _init_possible_resolutions(
        self,
        config,
        vision_config,
    ):
        if not getattr(config, "possible_resolutions", []):
            possible_resolutions = []
            if config.anyres:
                assert config.max_num_grids > 0
                for i in range(1, config.max_num_grids + 1):
                    for j in range(1, config.max_num_grids + 1):
                        if i == 1 and j == 1 and not config.use_1x1_grid:
                            continue
                        if i * j <= config.max_num_grids:
                            possible_resolutions.append([i, j])

978
979
980
981
                possible_resolutions = [
                    [ys * vision_config.image_size, xs * vision_config.image_size]
                    for ys, xs in possible_resolutions
                ]
982
983
984
985
986
987
988
989
990
991
992
993
            return possible_resolutions
        else:
            return config.possible_resolutions

    def _init_mm_projector(
        self,
        config,
        text_config,
        vision_config,
    ):
        input_hidden_size = vision_config.hidden_size
        if config.mm_projector_type == "linear":
994
            mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
995
996
997
998
            mm_projector.dtype = next(mm_projector.parameters()).dtype
        elif config.mm_projector_type == "cabstractor":
            mm_projector = HCXVisionCAbstractor(
                num_queries=config.num_queries_vis_abstractor_image,
999
1000
                num_input_tokens=(vision_config.image_size // vision_config.patch_size)
                ** 2,
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
                encoder_hidden_size=input_hidden_size,
                hidden_size=input_hidden_size,
                output_hidden_size=text_config.hidden_size,
                pos_emb=config.proj_pos_emb,
                prenorm=config.proj_prenorm,
            )
        else:
            mm_projector = HCXVisionMlp(
                config.mm_projector_type,
                input_hidden_size,
                hidden_features=input_hidden_size,
                out_features=self.text_config.hidden_size,
            )
        return mm_projector


def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor:
    """Crop the symmetric padding added when an image was fitted into ``tensor``.

    Args:
        tensor: a 3-D tensor. Its last two dims are treated as
            ``(height, width)``.
        original_size: ``(width, height)`` of the image before padding.

    Returns:
        A slice of ``tensor``. If the original image was relatively wider than
        ``tensor``, rows are cropped. Otherwise columns are cropped. The same
        amount is removed from both ends.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Width-bound fit: the content spans the full width, so the padding
        # sits above and below it.
        kept_h = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - kept_h) // 2
        return tensor[:, pad : cur_h - pad, :]

    # Height-bound fit (or equal aspect ratio): the padding sits left and right.
    kept_w = int(orig_w * (cur_h / orig_h))
    pad = (cur_w - kept_w) // 2
    return tensor[:, :, pad : cur_w - pad]


def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """Pick the candidate resolution that best fits an image.

    Args:
        original_size: ``(height, width)`` of the source image.
        possible_resolutions: candidate ``(height, width)`` pairs.

    Each candidate gets two scores:

    - Effective resolution: the pixel count after scaling the image to fit
      inside the candidate (aspect ratio kept, dims truncated to ints),
      capped at the original pixel count.
    - Wasted resolution: the candidate's area minus the effective resolution.

    The largest effective resolution wins. Ties go to the least waste, and
    any remaining ties to the earliest candidate.

    Returns:
        The winning ``(height, width)`` as a tuple, or ``None`` if
        ``possible_resolutions`` is empty.
    """
    src_h, src_w = original_size
    src_area = src_w * src_h

    def _score(candidate):
        cand_h, cand_w = candidate
        fit = min(cand_w / src_w, cand_h / src_h)
        effective = min(int(src_w * fit) * int(src_h * fit), src_area)
        wasted = cand_w * cand_h - effective
        # Lexicographic ranking: more effective pixels first, then less waste.
        return effective, -wasted

    # max() returns the first maximal element, matching "earliest wins" ties.
    best = max(possible_resolutions, key=_score, default=None)
    if best is None:
        return None
    best_h, best_w = best
    return (best_h, best_w)


def get_anyres_image_grid_shape(
    image_size: tuple[int, int],
    grid_pinpoints: str | list[tuple[int, int]],
    patch_size: int,
) -> tuple[int, int]:
    """Return the any-res tile grid for an image as ``(num_cols, num_rows)``.

    Args:
        image_size: ``(width, height)`` of the original image.
        grid_pinpoints: candidate ``(height, width)`` canvas sizes. This is
            either a non-string sequence (list or tuple) or a string holding
            a Python literal of one, parsed with ``ast.literal_eval``.
        patch_size: edge length of one tile, in the same units as the
            candidates.

    Returns:
        ``(width // patch_size, height // patch_size)`` for the candidate
        chosen by ``select_best_resolution``.

    Raises:
        ValueError: if no resolution could be selected, e.g. because
            ``grid_pinpoints`` is empty.
    """
    if isinstance(grid_pinpoints, str):
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    else:
        # Any non-string sequence is used directly. Previously only lists were
        # accepted, and tuples were passed to ``literal_eval``, which rejects
        # them.
        possible_resolutions = grid_pinpoints

    original_width, original_height = image_size
    best = select_best_resolution(
        (original_height, original_width), possible_resolutions
    )
    if best is None:
        # Fail with a clear message instead of a TypeError from unpacking None.
        raise ValueError(
            f"Could not select a resolution from grid_pinpoints={grid_pinpoints!r}"
        )
    height, width = best
    return width // patch_size, height // patch_size


def reshape_and_unpad_image_features(
    image_feature: torch.Tensor,
    height: int,
    width: int,
    image_size: tuple[int, int],
    possible_resolutions: list[tuple[int, int]],
    grid_size: int,
    unpad: bool,
    image_newline: torch.Tensor,
) -> torch.Tensor:
    """Merge a base view and its grid tiles into one token sequence.

    ``image_feature[0]`` is the base view and must have ``height * width``
    tokens (asserted). ``image_feature[1:]`` holds the grid tiles. They are
    viewed as ``(rows, cols, height, width, dim)``, where ``(cols, rows)``
    comes from ``get_anyres_image_grid_shape``.

    - ``unpad=False``: tiles are flattened row by row across the whole grid.
    - ``unpad=True``: tiles are stitched into one ``(dim, H, W)`` canvas and
      cropped with ``unpad_image``. ``image_newline`` is then appended as an
      extra column, and the result is flattened to ``(H * (W + 1), dim)``.

    Returns:
        The base-view tokens followed by the merged tile tokens.
    """
    base_image_feature, tiles = image_feature[0], image_feature[1:]

    assert height * width == base_image_feature.shape[0], (
        f"{height=} * {width=} != {base_image_feature.shape[0]=}"
    )

    n_cols, n_rows = get_anyres_image_grid_shape(
        image_size, possible_resolutions, grid_size
    )
    tiles = tiles.view(n_rows, n_cols, height, width, -1)

    if not unpad:
        # (rows, height, cols, width, dim) -> one token per row of the grid.
        merged = tiles.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 3)
        return torch.cat((base_image_feature, merged), dim=0)

    # (dim, rows, height, cols, width) -> (dim, rows*height, cols*width)
    canvas = tiles.permute(4, 0, 2, 1, 3).contiguous().flatten(1, 2).flatten(2, 3)
    canvas = unpad_image(canvas, image_size)
    newline_column = (
        image_newline[:, None, None]
        .expand(*canvas.shape[:-1], 1)
        .to(canvas.device)
    )
    canvas = torch.cat((canvas, newline_column), dim=-1)
    merged = canvas.flatten(1, 2).transpose(0, 1)
    return torch.cat((base_image_feature, merged), dim=0)


def anyres_postprocessing(
    image_forward_outs: list[torch.Tensor],
    image_sizes: list[list[int]],
    possible_resolutions: list[tuple[int, int]],
    patch_size: int,
    grid_size: int,
    image_newline: torch.Tensor,
    num_queries_vis_abstractor: int = -1,
    unpad: bool = False,
) -> list[torch.Tensor]:
    """Turn each image's per-view features into a single token sequence.

    The per-view token grid side is ``grid_size // patch_size``. If
    ``num_queries_vis_abstractor`` is positive, the side is instead the
    integer square root of that value.

    For each entry of ``image_forward_outs``:

    - More than one view along dim 0: the entry is merged by
      ``reshape_and_unpad_image_features``.
    - Otherwise: the single view is taken and ``image_newline`` is appended
      as one extra token.

    Raises:
        ValueError: if ``num_queries_vis_abstractor`` is positive but not a
            perfect square.
    """
    import math

    height = width = grid_size // patch_size

    if num_queries_vis_abstractor > 0:
        side = math.isqrt(int(num_queries_vis_abstractor))
        # Exact integer check instead of a float ``** 0.5`` test, and a real
        # exception instead of ``assert``, which ``python -O`` strips.
        if side * side != num_queries_vis_abstractor:
            raise ValueError(
                "n_queries must be square number, "
                f"got num_queries_vis_abstractor={num_queries_vis_abstractor}"
            )
        height = width = side

    # post-processing (unpad, add newline)
    new_image_features = []
    for image_idx, image_feature in enumerate(image_forward_outs):
        if image_feature.shape[0] > 1:
            merged = reshape_and_unpad_image_features(
                image_feature=image_feature,
                height=height,
                width=width,
                image_size=image_sizes[image_idx],
                possible_resolutions=possible_resolutions,
                grid_size=grid_size,
                unpad=unpad,
                image_newline=image_newline,
            )
        else:
            single_view = image_feature[0]
            merged = torch.cat(
                (single_view, image_newline[None].to(single_view.device)), dim=0
            )
        new_image_features.append(merged)

    return new_image_features