keye.py 59.7 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
4
from abc import abstractmethod
5
6
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
7
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
8
9
10
11
12
13
14
15

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
16
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
17
18
from transformers.utils import torch_int

19
from vllm.attention.backends.registry import AttentionBackendEnum
20
21
22
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
23
from vllm.config import VllmConfig
24
from vllm.config.multimodal import BaseDummyOptions
25
26
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
32
33
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
34
35
36
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
37
from vllm.model_executor.models.module_mapping import MultiModelKeys
38
from vllm.multimodal import MULTIMODAL_REGISTRY
39
40
41
42
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
43
    MultiModalFeatureSpec,
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
61
from vllm.multimodal.profiling import BaseDummyInputsBuilder
62
from vllm.platforms import current_platform
63
from vllm.sequence import IntermediateTensors
64
from vllm.utils.tensor_schema import TensorSchema, TensorShape
65

66
67
68
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
69
    SupportsMRoPE,
70
71
72
    SupportsMultiModal,
    SupportsPP,
)
73
from .siglip import SiglipMLP
74
75
76
77
78
79
80
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    is_pp_missing_parameter,
    maybe_prefix,
)
81
82
83
84
85
86
87
88
from .vision import get_vit_attn_backend

logger = init_logger(__name__)


def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
):
    """Round an image size to multiples of ``factor`` within a pixel budget.

    A side shorter than ``factor`` is raised to ``factor``, and the other side
    is rescaled by the same ratio. Both sides are then rounded to the nearest
    multiple of ``factor``. If the rounded area exceeds ``max_pixels``, the
    size is scaled down (flooring each side to a multiple of ``factor``). If
    it falls below ``min_pixels``, it is scaled up (ceiling each side to a
    multiple of ``factor``).

    Args:
        height: Input height in pixels.
        width: Input width in pixels.
        factor: Granularity that both output sides are rounded to.
        min_pixels: Lower bound on the output area.
        max_pixels: Upper bound on the output area.

    Returns:
        ``(h_bar, w_bar)``: the adjusted height and width.

    Raises:
        ValueError: If the aspect ratio (long side / short side) exceeds 200.
    """
    if height < factor:
        logger.warning(
            "smart_resize: height=%s < factor=%s, reset height=factor",
            height,
            factor,
        )
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logger.warning(
            "smart_resize: width=%s < factor=%s, reset width=factor",
            width,
            factor,
        )
        height = round((height * factor) / width)
        width = factor

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        # The original message lacked the ``f`` prefix and printed the raw
        # expression text instead of the offending ratio.
        raise ValueError(
            "absolute aspect ratio must be smaller than 200, got "
            f"{aspect_ratio}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Shrink both sides by the same ratio, flooring to stay under budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Grow both sides by the same ratio, ceiling to reach the minimum.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


129
class KeyeImagePixelInputs(TensorSchema):
    """
    Raw image patches plus the per-image patch grid.

    Dimensions:
        - bnp: Batch size * Number of patches
        - c: Number of channels (fixed to 3 in the ``pixel_values`` shape)
        - ps: Patch size
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})
    ]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
144
145


146
class KeyeImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings plus the per-image patch grid.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["image_embeds"]
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
159
160


161
# Image input accepted by the model: raw pixel patches or precomputed embeddings.
KeyeImageInputs: TypeAlias = KeyeImagePixelInputs | KeyeImageEmbeddingInputs
162
163


164
class KeyeVideoPixelInputs(TensorSchema):
    """
    Raw video-frame patches plus the per-video patch grid.

    Dimensions:
        - bnp: Batch size * Number of patches
        - c: Number of channels (fixed to 3 in the ``pixel_values_videos``
          shape)
        - ps: Patch size
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[
        torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})
    ]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
179
180


181
class KeyeVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings plus the per-video patch grid.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["video_embeds"]
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
194
195


196
# Video input accepted by the model: raw pixel patches or precomputed embeddings.
KeyeVideoInputs: TypeAlias = KeyeVideoPixelInputs | KeyeVideoEmbeddingInputs
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214


class KeyeVisionEmbeddings(nn.Module):
    """Patch embedding plus position embedding for the Keye vision tower.

    Patches are embedded with a ``patch_size``-strided convolution. Position
    information is added in one of two ways:

    * with ``interpolate_pos_encoding=True`` and ``image_grid_thw`` given, the
      learned square ``position_embedding`` grid is bilinearly resized to each
      image's ``(h, w)`` patch grid and repeated ``t`` times;
    * otherwise, explicit ``position_ids`` are looked up in
      ``packing_position_embedding``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # The learned position grid is square: (image_size // patch_size) ** 2.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        # Interpolated position embeddings keyed by (h, w), with per-key hit
        # counts used for least-frequently-used eviction.
        self.cache_position_embedding = {}
        self.cache_position_count = {}
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Lookup table for explicit packed position ids (32768 entries).
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        """Bilinearly resize the learned position grid to a new patch grid.

        Args:
            embeddings: Only its last dimension (the embedding size) is read.
            height: Target height, in patches if ``is_after_patchify`` is
                true, otherwise in pixels (divided by ``patch_size``).
            width: Target width, with the same convention as ``height``.
            is_after_patchify: Whether ``height``/``width`` are patch counts.

        Returns:
            Tensor of shape ``(1, new_height * new_width, dim)``.
        """
        num_positions = self.position_embedding.weight.shape[0]

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        # (1, H, W, dim) -> (1, dim, H, W) so interpolate resizes the grid.
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = 20):
        """Memoized ``interpolate_pos_encoding(embeddings, h, w, True)``.

        Results are keyed on ``(h, w)``. At most ``max_cache`` entries are
        kept; when full, the entry with the fewest hits is evicted. Cached
        tensors are returned as-is and are shared between callers.
        """
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]

        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)

        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        """Embed patches and add position embeddings.

        Args:
            pixel_values: 5-D ``(b, l, c, h, w)`` patch tensor. A 4-D input
                gets a leading batch axis added.
            position_ids: Packed position ids. Required unless position
                embeddings are interpolated from ``image_grid_thw``.
            image_grid_thw: ``(t, h, w)`` patch grid per image, in the order
                the patches appear.
            interpolate_pos_encoding: Interpolate the learned grid per image
                instead of using ``position_ids``.

        Returns:
            On the interpolation path, a tensor of shape ``(1, N, dim)``. On
            the packed path, the broadcast sum of the patch embeddings and the
            looked-up position embeddings.

        Raises:
            ValueError: If ``pixel_values`` is not 4-D/5-D, or if
                ``position_ids`` is needed but missing.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() != 5:
            raise ValueError(
                "Unsupported pixel_values dimension:"
                f" {pixel_values.dim()}. Expected 4 or 5."
            )

        use_interpolation = interpolate_pos_encoding and image_grid_thw is not None
        # position_ids is only read on the packed path; the interpolation path
        # derives positions from image_grid_thw.
        if not use_interpolation and position_ids is None:
            raise ValueError(
                "position_ids cannot be None when pixel_values.dim() is 5 "
                "and position embeddings are not interpolated."
            )

        target_dtype = self.patch_embedding.weight.dtype
        # Fold batch and patch axes so every patch is one conv input.
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # Presumably each input is exactly one patch, so the conv output is
        # 1x1 spatially and this yields (num_patches, dim); TODO confirm.
        embeddings = patch_embeds.flatten(-2).squeeze(-1)

        if not use_interpolation:
            return embeddings + self.packing_position_embedding(position_ids)

        per_image_embeddings = []
        start = 0
        for t, h, w in image_grid_thw:
            end = start + t * h * w
            image_embeddings = embeddings[start:end, :]
            # One (h * w) position grid, repeated for each of the t frames.
            position_embedding = (
                self.interpolate_pos_encoding(image_embeddings, h, w, True)
                .squeeze(0)
                .repeat(t, 1)
            )
            per_image_embeddings.append(image_embeddings + position_embedding)
            start = end
        return torch.concat(per_image_embeddings, dim=0).unsqueeze(0)
334
335
336
337
338
339
340
341
342
343
344


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k`` with a fused kernel.

    Only the first half of the last dimension of ``cos``/``sin`` is passed to
    the kernel. The computation runs in float32, and the results are cast
    back to the input dtypes.

    The kernel is platform-specific: vLLM's flash-attn rotary op on CUDA, or
    the flash-attn Triton rotary op on ROCm.

    Raises:
        NotImplementedError: On platforms other than CUDA and ROCm. Before
            this check, such platforms failed with ``UnboundLocalError``.
    """
    # Cast once instead of once per q/k call.
    cos = cos.chunk(2, dim=-1)[0].contiguous().float()
    sin = sin.chunk(2, dim=-1)[0].contiguous().float()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    elif current_platform.is_rocm():
        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
    else:
        raise NotImplementedError(
            "apply_rotary_pos_emb_flashatt requires a CUDA or ROCm platform."
        )

    q_embed = apply_rotary_emb(q.float(), cos, sin).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos, sin).type_as(k)
    return q_embed, k_embed


class KeyeSiglipAttention(nn.Module):
    """Multi-head self-attention for the Keye SigLIP vision encoder.

    Tokens of several sequences are packed along one axis and delimited by
    ``cu_seqlens`` (cumulative sequence boundaries), so attention never
    crosses a sequence boundary. FLASH_ATTN and ROCM_AITER_FA run through a
    varlen flash-attention function; XFORMERS runs through a block-diagonal
    mask. Any other backend is rejected at construction time.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        # Plain MHA: as many KV heads as query heads.
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                use_upstream_fa=False,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.XFORMERS,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
        cu_seqlens: list[torch.Tensor] | None = None,
        rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Attend within each packed sequence.

        Args:
            hidden_states: ``(b, s, hidden)`` packed tokens.
            attention_mask: Unused; sequence boundaries come from
                ``cu_seqlens``.
            output_attentions: Unused.
            cu_seqlens: Tensor of cumulative sequence boundaries over the
                packed tokens. Required by every backend.
            rope_emb: Optional ``(cos, sin)`` rotary tables applied to q/k.

        Returns:
            ``(b, s, hidden)`` output of ``out_proj``.

        Raises:
            ValueError: If ``cu_seqlens`` is None. Previously this surfaced as
                a ``TypeError``, because ``cu_seqlens`` was dereferenced
                before the check.
        """
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None.")

        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        batch_size = q.shape[0]

        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)
        if rope_emb is not None:
            cos, sin = rope_emb
            q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)

        if self.is_flash_attn_backend:
            # Only the flash path needs max_seqlen; .item() forces a host sync.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in (q, k, v))
            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=False,
                softmax_scale=self.scale,
            )
            context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Only xformers needs the per-sequence lengths as a Python list.
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )
            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
        else:
            # Unreachable: __init__ rejects every other backend.
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now."
            )

        context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        return output


class SigLIPRotaryEmbedding(nn.Module):
    """Rotary-embedding angle table for integer positions.

    ``inv_freq`` holds ``theta ** (-i / dim)`` for even ``i`` in
    ``[0, dim)``. ``forward(n)`` returns the ``(n, len(inv_freq))`` matrix
    whose entry ``[p, j]`` is ``p * inv_freq[j]``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        """Compute ``inv_freq`` and register it as a non-persistent buffer."""
        even_indices = torch.arange(0, self.dim, 2, dtype=torch.float)
        exponents = even_indices / self.dim
        self.register_buffer(
            "inv_freq", 1.0 / (self.theta**exponents), persistent=False
        )

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, len(inv_freq))`` table of rotation angles."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        # Column vector times row vector == outer product.
        return positions.unsqueeze(-1) * inv_freq.unsqueeze(0)


class KeyeSiglipEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x + attn(ln1(x))``, then ``y + mlp(ln2(y))``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = KeyeSiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_backend_override=attn_backend_override,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: bool | None = False,
        cu_seqlens: list[torch.Tensor] | None = None,
        rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run one encoder block.

        ``attention_mask``, ``output_attentions``, ``cu_seqlens`` and
        ``rope_emb`` are forwarded unchanged to the attention module.

        Returns:
            The updated hidden states tensor. The previous annotation
            ``tuple[torch.FloatTensor]`` did not match what is returned.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states


class KeyeSiglipEncoder(nn.Module):
    """Stack of ``KeyeSiglipEncoderLayer`` blocks with optional 2-D RoPE.

    When ``use_rope`` is ``True``, every token gets a (row, column) position.
    The rotary angles for both axes are concatenated into one ``(cos, sin)``
    pair that is shared by all layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList(
            [
                KeyeSiglipEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # Each of the two axes gets a quarter of head_dim worth of frequencies.
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of nesting in ``image_grid_thw``.

        Entries that are ``list`` instances are spliced into the result; all
        other entries (e.g. ``(t, h, w)`` tuples) are appended unchanged.
        """
        tmp_image_grid_thw = []
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def _build_rope_emb(
        self,
        image_grid_thw,
        height_position_ids: torch.Tensor | None,
        width_position_ids: torch.Tensor | None,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(cos, sin)`` rotary tables for every token.

        If either position-id tensor is missing, both are derived from
        ``image_grid_thw``. Within each ``(t, h, w)`` grid, token ``i`` is at
        row ``(i % (h * w)) // w`` and column ``i % w``. ``image_grid_thw`` is
        only read in that case.
        """
        if width_position_ids is None or height_position_ids is None:
            split_hids = []
            split_wids = []
            for t, h, w in self.flatten_list(image_grid_thw):
                image_pids = torch.arange(t * h * w, device=device) % (h * w)
                split_hids.append(image_pids // w)
                split_wids.append(image_pids % w)
            width_position_ids = torch.concat(split_wids, dim=0)
            height_position_ids = torch.concat(split_hids, dim=0)

        pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
        # Build the angle table once, up to the largest coordinate, then gather.
        max_grid_size = pids.max() + 1
        rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
        # For 1-D ids: (N, 2, d) -> (N, 2d); then duplicated along the last axis.
        rope_emb = rope_emb_max_grid[pids].flatten(1)
        rope_emb = rope_emb.repeat(1, 2)
        return rope_emb.cos(), rope_emb.sin()

    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cu_seqlens: list[torch.Tensor] | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
        use_rope: bool | None = False,
        window_size: bool | None = -1,
        vision_or_text: str = "vision",
    ) -> torch.Tensor:
        """Run every encoder layer over the packed ``inputs_embeds``.

        ``output_hidden_states``, ``window_size`` and ``vision_or_text`` are
        accepted for interface compatibility but are not used.

        Returns:
            The hidden states after the last layer. The previous annotation
            ``BaseModelOutput`` did not match what is returned.

        Raises:
            ValueError: If ``attention_mask`` is not None. Sequence boundaries
                come from ``cu_seqlens``. This replaces an ``assert``, which
                ``python -O`` strips.
        """
        if attention_mask is not None:
            raise ValueError(
                "KeyeSiglipEncoder does not support attention_mask; "
                "sequences are delimited by cu_seqlens."
            )

        if use_rope is True:
            rope_emb = self._build_rope_emb(
                image_grid_thw,
                height_position_ids,
                width_position_ids,
                inputs_embeds.device,
            )
        else:
            rope_emb = None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=cu_seqlens,
                rope_emb=rope_emb,
            )
        return hidden_states


class KeyeSiglipVisionTransformer(nn.Module):
    """Vision transformer made of embeddings, an encoder, and a final LayerNorm.

    The normalized output is split back into one tensor per packed sequence
    using ``cu_seqlens``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = KeyeVisionEmbeddings(config)
        self.encoder = KeyeSiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool | None = False,
        attention_mask: torch.Tensor | None = None,
        sample_indices: torch.Tensor | None = None,
        image_indices: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
        cu_seqlens: list[torch.Tensor] | None = None,
        padding_mask: torch.Tensor | None = None,
        vision_return_embed_list: bool | None = False,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        return_pooler_output: bool | None = True,
        use_rope: bool | None = False,
        window_size: bool | None = -1,
    ) -> list[torch.Tensor]:
        """Encode packed patches and return one hidden-state tensor per sequence.

        ``sample_indices``, ``image_indices``, ``padding_mask``,
        ``vision_return_embed_list`` and ``return_pooler_output`` are accepted
        for interface compatibility but are not used.

        Returns:
            A list with one tensor per ``cu_seqlens`` interval. Each tensor is
            ``last_hidden_state[:, start:end, :]`` with dim 0 squeezed.

        Raises:
            ValueError: If ``cu_seqlens`` is None. This is now checked before
                any work. Previously the encoder failed with a ``TypeError``
                before this check could run.
        """
        if cu_seqlens is None:
            raise ValueError(
                "cu_seqlens cannot be None for "
                "SiglipVisionTransformer output processing."
            )

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        # One host transfer for all boundaries, instead of indexing the
        # tensor element by element (one sync per sequence).
        boundaries = cu_seqlens.tolist()
        return [
            last_hidden_state[:, start:end, :].squeeze(0)
            for start, end in zip(boundaries[:-1], boundaries[1:])
        ]


class KeyeSiglipVisionModel(nn.Module):
    """Top-level Keye SigLIP vision tower, including checkpoint loading.

    ``forward`` delegates to ``KeyeSiglipVisionTransformer``. ``load_weights``
    routes separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors into
    the fused ``qkv_proj`` parameters. It skips rotary-frequency buffers and
    any ``head.*`` pooling-head weights.
    """

    config_class = PretrainedConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.vision_model = KeyeSiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            attn_backend_override=attn_backend_override,
        )
        self.quant_config = quant_config

    def _patch_embedding_weight(self) -> torch.Tensor:
        """Weight of the patch-embedding conv; source of dtype and device."""
        return self.vision_model.embeddings.patch_embedding.weight

    @property
    def dtype(self) -> torch.dtype:
        return self._patch_embedding_weight().dtype

    @property
    def device(self) -> torch.device:
        return self._patch_embedding_weight().device

    def get_input_embeddings(self) -> nn.Module:
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        position_ids: torch.Tensor | None = None,
        vision_return_embed_list: bool | None = False,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        cu_seqlens: list[torch.Tensor] | None = None,
        return_pooler_output: bool | None = True,
        use_rope: bool | None = False,
        window_size: bool | None = -1,
    ) -> BaseModelOutputWithPooling:
        """Forward every argument, by keyword, to the inner vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            interpolate_pos_encoding=interpolate_pos_encoding,
            use_rope=use_rope,
            window_size=window_size,
            sample_indices=sample_indices,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            vision_return_embed_list=vision_return_embed_list,
            return_pooler_output=return_pooler_output,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Returns:
            The set of parameter names that were loaded.
        """
        # (fused parameter name, checkpoint sub-name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # Checkpoint entries containing any of these substrings are ignored.
        ignored_markers = (
            "rotary_emb.inv_freq",
            "head.attention",
            "head.layernorm",
            "head.mlp",
            "head.probe",
        )
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if any(marker in name for marker in ignored_markers):
                continue

            cache_scale_name = (
                self.quant_config.get_cache_scale(name)
                if self.quant_config is not None
                else None
            )
            if cache_scale_name:
                scale_param = params_dict[cache_scale_name]
                scale_loader = getattr(
                    scale_param, "weight_loader", default_weight_loader
                )
                # Non-scalar scales are stored per shard; keep the first entry.
                if loaded_weight.dim() != 0:
                    loaded_weight = loaded_weight[0]
                scale_loader(scale_param, loaded_weight)
                loaded_params.add(cache_scale_name)
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                fused_param = params_dict[name]
                fused_param.weight_loader(fused_param, loaded_weight, shard_id)
                break
            else:
                # Not part of a fused projection: load it directly.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                target_param = params_dict[name]
                target_loader = getattr(
                    target_param, "weight_loader", default_weight_loader
                )
                target_loader(target_param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Projector(nn.Module):
    """Vision-to-text projector ("mlp_AR") for Keye-VL.

    Each ``merge_kernel_size`` (2x2) spatial window of vision-tower patch
    features is concatenated into one vector and mapped to the language
    model's hidden size through LayerNorm -> Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        # Spatial (height, width) window merged into one output token.
        self.merge_kernel_size = (2, 2)

        # Width of one merged token: features of every patch in the window
        # are concatenated.
        self.hidden_size = (
            self.vision_config.hidden_size
            * self.merge_kernel_size[0]
            * self.merge_kernel_size[1]
        )

        # Applied per patch, i.e. before the merge.
        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05)
        self.act = GELUActivation()

        self.linear_1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.linear_2 = RowParallelLinear(
            self.hidden_size,
            self.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def _project(self, merged: torch.Tensor) -> torch.Tensor:
        """Run linear_1 -> GELU -> linear_2 on already-merged features."""
        # vLLM parallel linear layers return an ``(output, output_bias)`` pair.
        hidden_states, _ = self.linear_1(merged)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states

    def forward(
        self,
        image_features: torch.Tensor | list[torch.Tensor],
        image_grid_thw: torch.Tensor | Sequence[Sequence[int]],
    ) -> torch.Tensor | list[torch.Tensor]:
        """Merge and project vision features.

        Args:
            image_features: Either a list/tuple with one 2-D tensor of shape
                ``(t * h * w, vision_hidden)`` per item, or a single tensor
                whose last dimension is the vision hidden size.
            image_grid_thw: Per-item ``(t, h, w)`` patch grids. Only used for
                the list/tuple input form.

        Returns:
            A list of projected per-item tensors for list/tuple input,
            otherwise a single tensor.
        """
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = []
            for image_feature, image_grid in zip(image_features, image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                # Rows of a grid tensor yield 0-d tensors; einops wants ints.
                t, h, w = (int(v) for v in image_grid)
                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                processed_features.append(self._project(image_feature))

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        # math.prod keeps this an int (np.prod(()) would be the float 1.0).
        image_features = image_features.view(math.prod(dims), dim)
        # Consecutive groups of (m1 * m2) normalized patches form one token.
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        hidden_states = self._project(hidden_states)

        # NOTE(review): the number of rows shrank by m1 * m2 above, so this
        # view re-splits the text hidden size across the original `dims`;
        # confirm this is the intended output shape.
        return hidden_states.view(*dims, -1)


def _keye_field_config(
    hf_inputs: Mapping[str, torch.Tensor],
) -> dict[str, MultiModalFieldConfig]:
    """Describe how Keye processor outputs are split into per-item fields.

    ``pixel_values`` / ``image_embeds`` and their video counterparts are
    split along the first dimension into chunks of ``t * h * w`` rows per
    item, taken from the matching ``*_grid_thw`` tensor. The grid tensors
    themselves are batched, one row per item.
    """
    # A missing modality falls back to an empty (0, 3) grid so its size
    # vector can still be computed.
    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
    image_grid_sizes = image_grid_thw.prod(-1)

    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
    video_grid_sizes = video_grid_thw.prod(-1)

    return dict(
        pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
        image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
        image_grid_thw=MultiModalFieldConfig.batched("image"),
        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
            "video", video_grid_sizes
        ),
        video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_grid_sizes),
        video_grid_thw=MultiModalFieldConfig.batched("video"),
    )


class KeyeMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed embeddings given as dicts."""

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any]:
        """Parse image input; a dict is treated as precomputed embeddings.

        A dict must carry ``image_embeds`` together with ``image_grid_thw``
        so per-item sizes can be derived by ``_keye_field_config``.
        Anything else is delegated to the base parser.
        """
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={
                    "image_embeds",
                    "image_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_image_data(data)

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any]:
        """Parse video input; a dict is treated as precomputed embeddings.

        A dict must carry ``video_embeds`` together with ``video_grid_thw``.
        Anything else is delegated to the base parser.
        """
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="video",
                required_fields={
                    "video_embeds",
                    "video_grid_thw",
                },
                fields_factory=_keye_field_config,
            )

        return super()._parse_video_data(data)


class KeyeProcessingInfo(BaseProcessingInfo):
    """Per-model processing metadata for Keye-VL: limits and token counts."""

    def get_max_image_size(self) -> int:
        """Side length used to probe the largest image.

        Deliberately huge; ``smart_resize`` (driven by the image processor's
        ``max_pixels``) is expected to bring it down to the real budget.
        """
        return 9999999  # _MAX_IMAGE_SIZE

    def get_max_frame_per_video(self) -> int:
        """Upper bound on frames per video used for token estimation."""
        return 16  # _MAX_FRAMES_PER_VIDEO

    def get_image_processor(self, **kwargs: object):
        """Return the image processor attached to the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(
        self,
    ) -> Mapping[str, int | None]:
        # ``None``: no fixed per-prompt item limit for either modality.
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Maximum vision tokens a single image / video item can produce."""
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor,
    ) -> tuple[ImageSize, int]:
        """Return the preprocessed size and vision-token count of an input.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for images).
            do_resize: Whether to apply ``smart_resize`` first.
            image_processor: Processor providing ``min_pixels`` /
                ``max_pixels``; ``None`` uses the default one.

        Returns:
            ``(preprocessed_size, num_vision_tokens)`` where the token count
            is the number of patches divided by ``spatial_merge_size ** 2``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        # Frames are not grouped temporally: each frame is its own slice.
        temporal_patch_size = 1

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # Round the frame count up to a multiple of the temporal patch size
        # (``n + n % tps`` is only correct for tps in {1, 2}).
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        """Number of vision tokens produced by one image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor,
    ) -> int:
        """Number of vision tokens produced by one video."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(
        self,
    ) -> ImageSize:
        """Preprocessed size of the largest possible input image."""
        max_image_size, _ = self._get_vision_info(
            image_width=self.get_max_image_size(),
            image_height=self.get_max_image_size(),
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Vision tokens produced by the largest possible image."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose max-size video fits in ``max_tokens``.

        Returns 0 when even a single frame does not fit.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        """Frames per video when the prompt is filled to the per-prompt limits.

        The token budget left after the maximum number of max-size images is
        shared across the allowed videos and capped by
        ``get_max_frame_per_video``; the result is always at least 1.
        """
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1),
            self.get_max_frame_per_video(),
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(self, seq_len: int) -> int:
        """Vision tokens produced by the largest video that fits ``seq_len``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len),
            image_processor=None,
        )


# Type variable so dummy-input builders can be specialized for any
# KeyeProcessingInfo subclass.
_I = TypeVar("_I", bound=KeyeProcessingInfo)


class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds placeholder prompts and maximum-size dummy media for Keye-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image/video placeholder token per requested item."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images/videos at the size yielding the most features.

        Args:
            seq_len: Sequence length used to size the dummy videos.
            mm_counts: Number of items requested per modality.
            mm_options: Optional per-modality overrides for the dummy data.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(seq_len)

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        mm_data = {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }

        return mm_data


class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
    """Dummy-input builder bound to :class:`KeyeProcessingInfo`."""


class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
    """Multimodal processor that expands Keye placeholder tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        return KeyeMultiModalDataParser()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/video token with one token per merged patch.

        Each item's token count is ``prod(grid_thw) // merge_size ** 2``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_keye(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [placeholder[modality]] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(get_replacement_keye, modality=modality),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _keye_field_config(hf_inputs)


class BaseKeyeModule(nn.Module):
    """Shared Keye-VL model: SigLIP vision tower, projector and Qwen3 LM.

    Subclasses provide the projector (``_build_projector``) and the hooks this
    class calls but does not define: ``_parse_and_validate_image_input``,
    ``_parse_and_validate_video_input`` and ``_process_video_input``.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Checkpoint text-model weights live under ``language_model`` in vLLM.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt placeholder for an image or video item."""
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: PretrainedConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.visual = KeyeSiglipVisionModel(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            attn_backend_override=attn_backend_override,
        )

        self.mlp_AR = self._build_projector(
            config,
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "mlp_AR"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    # nn.Module is not an ABC, so this decorator is not enforced at
    # instantiation; a missing override surfaces via the raise below.
    @abstractmethod
    def _build_projector(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> nn.Module:
        raise ValueError("Need projector")

    @staticmethod
    def _build_vision_metadata(
        grid_thw: torch.Tensor,
        device: torch.device,
    ) -> tuple[
        list[tuple[int, ...]], torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        """Build the per-patch metadata the vision tower needs.

        Args:
            grid_thw: 2-D tensor with one ``(t, h, w)`` row per item.
            device: Device for the returned tensors.

        Returns:
            ``(grid_thws, position_ids, sample_indices, cu_seqlens)``:
            the grids as tuples, position ids that restart every ``h * w``
            patches, the item index of each patch, and int32 cumulative patch
            counts starting at 0.
        """
        assert grid_thw.ndim == 2

        grid_thws: list[tuple[int, ...]] = []
        position_ids: list[torch.Tensor] = []
        sample_indices: list[torch.Tensor] = []
        cu_seqlens = [0]

        for idx, row in enumerate(grid_thw):
            thw = tuple(row.detach().cpu().numpy().tolist())
            numel = math.prod(thw)
            grid_thws.append(thw)
            position_ids.append(torch.arange(numel) % math.prod(thw[1:]))
            sample_indices.append(torch.full((numel,), idx, dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        return (
            grid_thws,
            torch.concat(position_ids, dim=0).to(device),
            torch.concat(sample_indices, dim=0).to(device),
            torch.tensor(cu_seqlens, dtype=torch.int32).to(device),
        )

    def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]:
        """Encode and project images; returns one embedding tensor per image.

        Raises:
            ValueError: If the input carries precomputed ``image_embeds``.
        """
        # Reject the unsupported path before doing any metadata work.
        if image_input["type"] == "image_embeds":
            raise ValueError(
                "Image embeddings are not supported for this processing path."
            )

        image_grid_thw = image_input["image_grid_thw"]
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_grid_hws, siglip_position_ids, sample_indices, cu_seqlens = (
            self._build_vision_metadata(image_grid_thw, pixel_values.device)
        )

        image_embeds = self.visual(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=False,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        return tuple(self.mlp_AR(image_embeds, image_grid_thw))

    def _process_video_embeds(
        self,
        video_type: Literal["video_embeds", "pixel_values_videos"],
        video_grid_thw: torch.Tensor,
        pixel_values_videos: torch.Tensor | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Encode and project video pixels.

        Args:
            video_type: Input kind; only ``"pixel_values_videos"`` is handled.
            video_grid_thw: 2-D tensor with one ``(t, h, w)`` row per video.
            pixel_values_videos: Flattened video patches.

        Raises:
            ValueError: If ``video_type`` is ``"video_embeds"``.
        """
        if video_type == "video_embeds":
            raise ValueError(
                "Video embeddings are not supported for this processing path."
            )

        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_grid_hws, siglip_position_ids, sample_indices, cu_seqlens = (
            self._build_vision_metadata(video_grid_thw, pixel_values_videos.device)
        )

        video_embeds = self.visual(
            pixel_values=pixel_values_videos,
            image_grid_thw=video_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=True,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        return self.mlp_AR(video_embeds, video_grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs, keyed in first-seen order."""
        modalities = {}

        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> MultiModalEmbeddings | None:
        """Return per-item embeddings for all image/video inputs, or None."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # Preserve the order in which modalities first appeared in kwargs.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Keye-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings. Ignored when
                ``intermediate_tensors`` is given.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming prefixes via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp_AR.",
            tower_model="visual.",
        )


@MULTIMODAL_REGISTRY.register_processor(
    KeyeMultiModalProcessor,
    info=KeyeProcessingInfo,
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(
    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """Keye-VL model built on ``BaseKeyeModule`` with a ``Projector``
    connector and Keye-specific M-RoPE position computation."""

    def _build_projector(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> nn.Module:
        """Build the ``Projector`` module from the text and vision configs."""
        return Projector(text_config, vision_config, quant_config, prefix)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KeyeImageInputs | None:
        """Extract image inputs from the model kwargs.

        ``pixel_values`` take precedence over ``image_embeds``. The
        ``image_grid_thw`` entry is passed through unchanged.

        Returns:
            A ``KeyeImagePixelInputs`` or ``KeyeImageEmbeddingInputs``, or
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return KeyeImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return KeyeImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> KeyeVideoInputs | None:
        """Extract video inputs from the model kwargs.

        ``pixel_values_videos`` take precedence over ``video_embeds``. The
        ``video_grid_thw`` entry is passed through unchanged.

        Returns:
            A ``KeyeVideoPixelInputs`` or ``KeyeVideoEmbeddingInputs``, or
            ``None`` when neither ``pixel_values_videos`` nor
            ``video_embeds`` is present.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return KeyeVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )
        if video_embeds is not None:
            return KeyeVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )
        return None

    def _process_video_input(
        self, video_input: KeyeVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Run video input through ``_process_video_embeds``.

        NOTE(review): only ``pixel_values_videos`` is forwarded (``None`` for
        embedding inputs); ``video_embeds`` itself is never passed. Confirm
        that ``_process_video_embeds`` handles the ``"video_embeds"`` type.
        """
        video_type = video_input["type"]
        video_grid_thw = video_input["video_grid_thw"]
        pixel_values_videos = video_input.get("pixel_values_videos", None)

        return tuple(
            self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
        )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE position ids for a Keye-VL prompt.

        Text tokens get the same position on all three rotary axes. Each
        image, and each individual video frame, gets a ``(t, h, w)`` grid of
        positions. In that grid, ``h`` and ``w`` are divided by
        ``spatial_merge_size``, and the grid starts right after the text
        that precedes it.

        Args:
            input_tokens: Prompt token ids.
            mm_features: Multimodal features of the request; only their
                ``image_grid_thw`` and ``video_grid_thw`` entries are used.

        Returns:
            ``(positions, delta)`` where ``positions`` has shape ``(3, N)``
            and ``delta = positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]

        def split_thw(grid_items: Iterable[Any]) -> list[list[int]]:
            """Flatten video grids into one ``[1, h, w]`` row per frame.

            Each item may be a single ``[t, h, w]`` row or an ``[N, 3]``
            tensor / nested list. Every row is repeated ``t`` times with its
            temporal size set to 1. All videos are kept, in order.
            """
            frames: list[list[int]] = []
            for item in grid_items:
                rows = torch.as_tensor(item, dtype=torch.long).reshape(-1, 3)
                for t, h, w in rows.tolist():
                    frames.extend([1, h, w] for _ in range(t))
            return frames

        # Expand every video, not only the first one. Otherwise the
        # placeholders of later videos would be positioned as plain text, and
        # a single ``[t, h, w]`` row could not be split along ``t``.
        video_grid_thw = split_thw(kwargs.get("video_grid_thw", []))

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        # Index sentinel meaning "no further placeholder of this kind".
        not_found = len(input_tokens) + 1

        def find_next(token_id: int, start: int, remaining: int) -> int:
            """Index of the next ``token_id`` at or after ``start``."""
            if remaining <= 0:
                return not_found
            try:
                return input_tokens.index(token_id, start)
            except ValueError:
                return not_found

        image_nums = len(image_grid_thw)
        frame_nums = len(video_grid_thw)
        llm_pos_ids_list: list[torch.Tensor] = []

        st = 0
        remain_images, remain_frames = image_nums, frame_nums
        image_index, video_index = 0, 0
        for _ in range(image_nums + frame_nums):
            ed_image = find_next(image_token_id, st, remain_images)
            ed_video = find_next(video_token_id, st, remain_frames)

            # Handle whichever placeholder (image or video frame) comes first.
            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_index += 1
                remain_frames -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text preceding the vision block: same position on all 3 axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            # Vision block: (t, h, w) grid offset past the preceding text.
            t_index = (
                torch.arange(llm_grid_t)
                .view(-1, 1)
                .expand(-1, llm_grid_h * llm_grid_w)
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision block.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        if not llm_pos_ids_list:
            # Empty prompt without multimodal items: torch.cat([]) would raise.
            return torch.zeros((3, 0), dtype=torch.long), 0

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return llm_positions, mrope_position_delta