kanana_v.py 26.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Literal, TypeAlias

import numpy as np
import regex as re
import torch
from einops import rearrange
from PIL import Image
from timm.layers import LayerNorm2d
from timm.layers.pos_embed import resample_abs_pos_embed
from timm.models.regnet import RegStage
from torch import nn
from transformers import BatchFeature
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
31
    BaseDummyInputsBuilder,
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .qwen2_vl import Qwen2VisionTransformer
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

logger = init_logger(__name__)


class KananaVImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - np: The total number of patches over all images in the batch
        - cps: Number of channels * patch_size * patch_size
        - ni: Number of images
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    vision_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


KananaVImageInputs: TypeAlias = KananaVImagePixelInputs


def build_pos_embeds(
    config: Qwen2VLVisionConfig,
    num_input_tokens: int,
    vision_hidden_size: int,
) -> nn.Parameter | None:
    """Build positional embeddings for the visual encoder output."""
    if config.pos_emb:
        pos_emb = nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
        nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
    else:
        pos_emb = None

    return pos_emb


def build_mlp(
    depth: int,
    hidden_size: int,
    output_hidden_size: int,
) -> nn.Sequential:
    """Simple SiLU-activated MLP used as a projector readout."""
    layers = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        layers.append(nn.SiLU())
        layers.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*layers)


class PatchMerge(nn.Module):
    """Merge neighboring patches spatially to reduce resolution."""

    def __init__(self, merge_size: int) -> None:
        super().__init__()
        self.merge_size = merge_size

    def forward(
        self,
        x: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Merge patches by `merge_size x merge_size`."""
        if channel_last:
            x = rearrange(x, "B H W D -> B D H W")
        _, _, H, W = x.shape
        merged_x = rearrange(
            x,
            "B D (H h2) (W w2) -> B (D h2 w2) H W",
            h2=self.merge_size,
            w2=self.merge_size,
        )
        return merged_x


class DynamicCAbstractor(nn.Module):
    """Dynamic C-Abstractor based on RegNet blocks."""

    def __init__(
        self,
        config: Qwen2VLVisionConfig,
        num_input_tokens: int,
    ) -> None:
        super().__init__()
        assert hasattr(config, "merge_size"), "merge_size must be provided."
        self.config = config
        self.merge_size = config.merge_size
        self.pos_emb_size = config.pos_emb_size
        if num_input_tokens == -1:
            num_input_tokens = config.pos_emb_size
        self.num_input_tokens = num_input_tokens
        self.pos_emb = build_pos_embeds(
            config, num_input_tokens, config.encoder_hidden_size
        )
        self.build_net()

    def _load_from_state_dict(self, state_dict, *args, **kwargs) -> None:
        if not state_dict:
            return

        if self.pos_emb is not None:
            key_re = re.compile(r"[\w,.]*abstractor[\w,.]*pos_emb")
            pos_emb_key = None
            for key in state_dict:
                if key_re.match(key):
                    pos_emb_key = key
                    break

            assert pos_emb_key is not None
            # update old ckpt compatible with current code
            pos_emb = state_dict[pos_emb_key]
            if pos_emb.size(1) == self.pos_emb.size(1) + 1:
                # remove obsolete first pos emb (for cls token originally)
                state_dict[pos_emb_key] = pos_emb[:, 1:]

        super()._load_from_state_dict(state_dict, *args, **kwargs)

    def build_net(self) -> None:
        encoder_hidden_size = self.config.encoder_hidden_size
        hidden_size = self.config.hidden_size
        output_hidden_size = self.config.output_hidden_size
        depth = self.config.depth
        mlp_depth = self.config.mlp_depth

        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
        sampler = PatchMerge(merge_size=self.merge_size)
        s2 = RegBlock(
            depth,
            self.merge_size**2 * hidden_size,
            hidden_size,
        )

        if depth:
            self.net = nn.ModuleList([s1, sampler, s2])
            self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
        else:
            self.net = sampler
            self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size)

    def forward(
        self,
        flattened_visual_embeds: torch.Tensor,
        grid_thw: torch.Tensor,
        **unused_kwargs: object,
    ) -> BaseModelOutput:
        """Apply the dynamic abstractor over flattened visual embeddings."""
        n_token_loc = torch.prod(grid_thw, dim=1)
        split_visual_embeds = torch.split(flattened_visual_embeds, n_token_loc.tolist())

        flattened_visual_embeds = []
        for _visual_embeds, _grid_thw in zip(split_visual_embeds, grid_thw):
            T, H, W = _grid_thw
            assert T == 1, "T must be 1. Video is not supported yet."
            reshaped_visual_embeds = rearrange(
                _visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
            )
            # remove temporal dim
            reshaped_visual_embeds = reshaped_visual_embeds[:, 0]

            if self.pos_emb is not None:
                # interpolate pos emb and add to visual embeds
                _local_pos_emb = resample_abs_pos_embed(
                    posemb=self.pos_emb,
                    old_size=tuple([int(self.pos_emb_size**0.5)] * 2),
                    new_size=(H, W),
                    num_prefix_tokens=0,
                )
                _local_pos_emb = rearrange(
                    _local_pos_emb,
                    "1 (h w) d -> 1 h w d",
                    h=H,
                    w=W,
                )
                reshaped_visual_embeds = reshaped_visual_embeds + _local_pos_emb

            reshaped_visual_embeds = self._forward(
                reshaped_visual_embeds,
                input_size=(H, W),
            )
            flattened_visual_embeds.append(reshaped_visual_embeds)
        reshaped_visual_embeds = torch.cat(flattened_visual_embeds, dim=0)
        return BaseModelOutput(last_hidden_state=reshaped_visual_embeds)

    def _forward(
        self,
        x: torch.Tensor,
        input_size: tuple[int, int],
    ) -> torch.Tensor:
        h, w = input_size
        x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w)
        if self.config.depth:
            x = self.net[0](x)
            x = self.net[1](x)
            x = self.net[2](x)
        else:
            # When depth=0, self.net is a single PatchMerge module
            x = self.net(x)
        x = rearrange(x, "1 d h w -> (h w) d")
        x = self.readout(x)
        return x


class CustomQwen2VLVE(Qwen2VisionTransformer):
    """Thin wrapper around the Qwen2-VL used as a vision encoder.

    This mirrors the original HF-based vision encoder used in Kanana-V, but
    reuses vLLM's optimized `Qwen2VisionTransformer` building blocks.
    """

    def __init__(self, config: Qwen2VLVisionConfig) -> None:
        super().__init__(
            vision_config=config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=None,
            prefix="",
        )

        # Kanana-V uses its own projector/abstractor instead of the Qwen2
        # built-in patch merger, so we drop the merger module to keep the
        # parameter set compatible with the original checkpoint.
        if hasattr(self, "merger"):
            del self.merger

    @classmethod
    def _from_config(cls, config: Qwen2VLVisionConfig) -> "CustomQwen2VLVE":
        """Drop-in replacement for the HF `_from_config` constructor."""
        return cls(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutput:
        """Run the vision transformer and optionally return intermediate states.

        Unlike the base `Qwen2VisionTransformer`, this wrapper exposes the
        pre-merger patch-level representations and a HF-style `BaseModelOutput`
        so that the existing projector / abstractor code can be reused.
        """
        assert return_dict, "Only return_dict=True is supported."

        # Patchify
        x = pixel_values.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)  # (num_patches, embed_dim)

        # Prepare grid and rotary embeddings – mirror base implementation.
        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw_np = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw_np = grid_thw.cpu().numpy()

        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # Compute cu_seqlens in numpy then move to device, same as base model.
        cu_seqlens = np.repeat(
            grid_thw_np[:, 1] * grid_thw_np[:, 2],
            grid_thw_np[:, 0],
        ).cumsum(axis=0, dtype=np.int32)
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens).to(
            self.device,
            non_blocking=True,
        )

        # Shape to (S, B, D) with batch dimension 1 as expected by the blocks.
        x = x.unsqueeze(1)

        # Pre-compute seqlens for attention backend.
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

        encoder_states = () if output_hidden_states else None

        for blk in self.blocks:
            if output_hidden_states:
                # Store patch-level states (S, D).
                encoder_states = encoder_states + (x.squeeze(1),)

            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # Final hidden state at patch level (S, D).
        hidden_states = x.squeeze(1)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
        )

    def get_num_tokens(self) -> int:
        # Not used in the current Kanana-V pipeline, kept for API compatibility.
        return -1


class KananaVProcessingInfo(BaseProcessingInfo):
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(
            image_width=9999,
            image_height=9999,
            num_frames=1,
        )
        return max_image_size

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
    ) -> tuple[ImageSize, int]:
        image_processor = self.ctx.get_hf_processor().image_processor
        smart_resize = resolve_obj_by_qualname(
            f"{type(image_processor).__module__}.smart_resize"
        )

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        target_width, target_height = self.get_image_size_with_most_features()
        num_vision_tokens = self._get_vision_info(
            image_width=target_width,
            image_height=target_height,
            num_frames=1,
        )[1]
        return {"image": num_vision_tokens}


class KananaVDummyInputsBuilder(BaseDummyInputsBuilder[KananaVProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        return "<image>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
448
        mm_processor_kwargs: Mapping[str, object] | None = None,
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        return {
            "image": self._get_dummy_images(
                width=9999, height=9999, num_images=num_images
            ),
        }


class KananaVMultiModalProcessor(BaseMultiModalProcessor[KananaVProcessingInfo]):
    """vLLM multimodal processor for Kanana-V (text + image)."""

    @property
    def media_token_id(self) -> int:
        return self.info.get_hf_config().text_config.eos_token_id + 1

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the underlying HF processor on text and image data."""
        # Text-only input is handled as a special case here.
        if not mm_data or not mm_data.get("images", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Images
        image_inputs = mm_data.get("images", [])
        pixel_sizes = []
        if not isinstance(image_inputs[0], Image.Image):
            image_inputs = [Image.fromarray(image) for image in image_inputs]

        image_processor = self.info.get_hf_processor().image_processor
        processor_output = [image_processor(image) for image in image_inputs]
        pixel_values = [o["pixel_values"] for o in processor_output]
        image_meta = [o["image_meta"] for o in processor_output]
        # list of dict -> dict of list
        image_meta = {k: [d[k] for d in image_meta] for k in image_meta[0]}

        for pixel_value in pixel_values:
            pixel_sizes.append(pixel_value.shape[0])
        # flattened pixel_values for single example (already includes batch dim)
        pixel_values = torch.concat(pixel_values, dim=0)

        tokenizer = self.info.get_tokenizer()
        media_token = tokenizer.convert_ids_to_tokens([self.media_token_id])[0]
        prompt_replaced = prompt.replace("<image>", media_token)
        input_ids = tokenizer.encode(prompt_replaced)
        input_ids = torch.tensor(input_ids)

        # Ensure HF output is consistent with vLLM prompt-update expectations:
        # if the HF tokenizer emits exactly 1 placeholder token per image, expand
        # it to `T*H*W` placeholder tokens per image so placeholder detection works.
        num_images = len(image_inputs)
        image_token_thw = torch.tensor(image_meta["image_token_thw"])
        per_image_token_counts = image_token_thw.prod(dim=1).tolist()
        expected_total = int(sum(int(x) for x in per_image_token_counts))

        n_placeholders = int((input_ids == self.media_token_id).sum().item())
        if n_placeholders == num_images and expected_total != num_images:
            expanded: list[int] = []
            img_i = 0
            for tok in input_ids.tolist():
                if tok == self.media_token_id and img_i < num_images:
                    expanded.extend(
                        [self.media_token_id] * int(per_image_token_counts[img_i])
                    )
                    img_i += 1
                else:
                    expanded.append(tok)
            input_ids = input_ids.new_tensor(expanded)

        combined_outputs = dict(
            # Add batch dimension to input_ids.
            input_ids=input_ids.unsqueeze(0),
            pixel_values=pixel_values,
            vision_grid_thw=torch.tensor(image_meta["vision_grid_thw"]),
            image_token_thw=torch.tensor(image_meta["image_token_thw"]),
            pixel_sizes=torch.tensor(pixel_sizes),
        )
        return BatchFeature(combined_outputs, tensor_type="pt")

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        def get_replacement(idx: int) -> Sequence[int]:
            out_item = out_mm_kwargs["image"][idx]
            image_token_thw = out_item["image_token_thw"].data
            assert isinstance(image_token_thw, torch.Tensor)

            num_tokens = int(image_token_thw.prod().item())
            return [self.media_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement,
            ),
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        pixel_sizes = hf_inputs.get("pixel_sizes", torch.empty(0))

        mm_fields_config = dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", pixel_sizes),
            vision_grid_thw=MultiModalFieldConfig.batched("image"),
            image_token_thw=MultiModalFieldConfig.batched("image"),
        )
        return mm_fields_config


@MULTIMODAL_REGISTRY.register_processor(
    KananaVMultiModalProcessor,
    info=KananaVProcessingInfo,
    dummy_inputs=KananaVDummyInputsBuilder,
)
class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"
        else:
            raise ValueError(f"Unsupported modality: {modality}")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.config = config

590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
            self.abstractor = DynamicCAbstractor(
                config.projector_config,
                num_input_tokens=self.vision_model.get_num_tokens(),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "model"),
                architectures=["LlamaForCausalLM"],
            )

605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KananaVImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        vision_grid_thw = kwargs.pop("vision_grid_thw", None)

        if pixel_values is None:
            return None

        if vision_grid_thw is None:
            raise ValueError(
                "vision_grid_thw is required when pixel_values is provided"
            )

        # Normalize pixel_values to 2D tensor (num_patches, channels*patch*patch)
        if isinstance(pixel_values, torch.Tensor):
            if pixel_values.ndim == 2:
                pass  # Already in expected shape
            elif pixel_values.ndim == 3:
                pixel_values = pixel_values.flatten(0, 1)
            else:
                raise ValueError(
                    f"pixel_values should be 2D or batched 3D tensor. "
                    f"Got ndim: {pixel_values.ndim} "
                    f"(shape={pixel_values.shape})"
                )
        else:
            pixel_values = torch.concat(pixel_values)

        # Normalize vision_grid_thw to 2D tensor (num_images, 3)
        if isinstance(vision_grid_thw, torch.Tensor):
            if vision_grid_thw.ndim == 3:
                vision_grid_thw = vision_grid_thw.flatten(0, 1)
        else:
            vision_grid_thw = torch.concat(vision_grid_thw)

        return KananaVImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            vision_grid_thw=vision_grid_thw,
        )

    def _process_image_input(self, image_input: KananaVImageInputs) -> torch.Tensor:
        pixel_values = image_input["pixel_values"]
        vision_grid_thw = image_input["vision_grid_thw"]

        image_metas = {"vision_grid_thw": vision_grid_thw}
        visual_embeds = self.forward_and_project_vision(pixel_values, image_metas)

        merge_size = self.abstractor.merge_size
        batch_size = vision_grid_thw.size(0)
        multi_modal_embeddings: tuple[torch.Tensor, ...] = ()
        sample_index = 0
        for i in range(batch_size):
            t, h, w = (
                vision_grid_thw[i][0],
                vision_grid_thw[i][1] // merge_size,
                vision_grid_thw[i][2] // merge_size,
            )
            num_tokens = t * h * w
            visual_embed = visual_embeds[sample_index : sample_index + num_tokens]
            multi_modal_embeddings += (visual_embed,)
            sample_index += num_tokens

        return multi_modal_embeddings

    def _get_visual_feature_at(
        self,
        v_output: Sequence[torch.Tensor],
        layer_index: int | Sequence[int],
    ) -> torch.Tensor:
        if isinstance(layer_index, (list, tuple)):
            visual_features = torch.stack(v_output, dim=1)[
                :, layer_index
            ]  # [B, n_scales, L, dim]
        else:
            visual_features = v_output[layer_index]  # [B, L, dim]
        return visual_features

    def forward_vision(
        self,
        pixel_values: torch.Tensor,
        image_metas: dict | None = None,
    ) -> torch.Tensor:
        vision_model_args = {
            "pixel_values": pixel_values,
            "return_dict": True,
            "output_hidden_states": True,
            "grid_thw": image_metas["vision_grid_thw"],
        }
        v_outputs = self.vision_model(**vision_model_args)
        layer_index = self.config.projector_config.feature_layer_index
        visual_features = self._get_visual_feature_at(
            v_outputs.hidden_states, layer_index
        )
        return visual_features

    def forward_projector(
        self,
        visual_features: torch.Tensor,
        image_metas: dict | None = None,
    ) -> torch.Tensor:
        visual_embeds = self.abstractor(
            visual_features,
            grid_thw=image_metas["vision_grid_thw"],
        )["last_hidden_state"]
        return visual_embeds

    def forward_and_project_vision(
        self,
        pixel_values: torch.Tensor,
        image_metas: dict | None = None,
    ) -> torch.Tensor:
        assert pixel_values is not None
        visual_features = self.forward_vision(pixel_values, image_metas=image_metas)
        visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
        return visual_embeds

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
736
        input_ids: torch.Tensor | None,
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)