keye.py 60.3 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
4
from abc import abstractmethod
5
6
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
7
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
8
9
10
11

import numpy as np
import torch
import torch.nn as nn
12
import torch.nn.functional as F
13
14
15
16
from einops import rearrange
from transformers import PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
17
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
18
19
from transformers.utils import torch_int

20
from vllm.attention.backends.registry import AttentionBackendEnum
21
22
23
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
24
from vllm.config import VllmConfig
25
from vllm.config.multimodal import BaseDummyOptions
26
27
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
28
from vllm.model_executor.layers.conv import Conv2dLayer
29
30
31
32
33
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
34
35
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
36
37
38
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
39
from vllm.model_executor.models.module_mapping import MultiModelKeys
40
from vllm.multimodal import MULTIMODAL_REGISTRY
41
42
43
44
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
45
    MultiModalFeatureSpec,
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ImageSize,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
63
from vllm.multimodal.profiling import BaseDummyInputsBuilder
64
from vllm.platforms import current_platform
65
from vllm.sequence import IntermediateTensors
66
from vllm.utils.tensor_schema import TensorSchema, TensorShape
67

68
69
70
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
71
    SupportsMRoPE,
72
73
74
    SupportsMultiModal,
    SupportsPP,
)
75
from .siglip import SiglipMLP
76
77
78
79
80
81
82
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    is_pp_missing_parameter,
    maybe_prefix,
)
83
84
85
86
87
88
89
90
from .vision import get_vit_attn_backend

# Module-level logger; used e.g. by smart_resize to warn about undersized inputs.
logger = init_logger(__name__)


def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
):
    """Round an image size to multiples of ``factor`` within a pixel budget.

    A side smaller than ``factor`` is raised to ``factor`` (the other side is
    rescaled to keep the aspect ratio).  Both sides are then rounded to the
    nearest multiple of ``factor``.  If the resulting area exceeds
    ``max_pixels`` the size is scaled down (flooring to multiples of
    ``factor``, but never below one ``factor``); if it is below
    ``min_pixels`` it is scaled up (ceiling to multiples of ``factor``).

    Args:
        height: Input height in pixels; must be positive.
        width: Input width in pixels; must be positive.
        factor: Every output side is a multiple of this value.
        min_pixels: Lower bound on the output area.
        max_pixels: Upper bound on the output area.

    Returns:
        ``(h_bar, w_bar)``: the resized height and width.

    Raises:
        ValueError: If ``height`` or ``width`` is not positive, or if the
            aspect ratio (after the ``factor`` clamp) exceeds 200.
    """
    if height <= 0 or width <= 0:
        # Without this guard the rescaling below divides by zero.
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )

    if height < factor:
        logger.warning(
            "smart_resize: height=%s < factor=%s, reset height=factor",
            height,
            factor,
        )
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logger.warning(
            "smart_resize: width=%s < factor=%s, reset width=factor",
            width,
            factor,
        )
        height = round((height * factor) / width)
        width = factor

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to one ``factor`` unit: for very elongated images the floor
        # would otherwise produce a zero-sized side.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


131
class KeyeImagePixelInputs(TensorSchema):
    """
    Raw image patches for the Keye vision tower.

    Dimensions:
        - bnp: Batch size * Number of patches
        - c: Number of channels (fixed to 3 in the schema below)
        - ps: Patch size
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["pixel_values"]
    pixel_values: Annotated[
        torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})
    ]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
146
147


148
class KeyeImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision tower.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
        - ni: Number of images
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["image_embeds"]
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
161
162


163
# An image input is either raw pixel patches or precomputed embeddings.
KeyeImageInputs: TypeAlias = KeyeImagePixelInputs | KeyeImageEmbeddingInputs
164
165


166
class KeyeVideoPixelInputs(TensorSchema):
    """
    Raw video-frame patches for the Keye vision tower.

    Dimensions:
        - bnp: Batch size * Number of patches
        - c: Number of channels (fixed to 3 in the schema below)
        - ps: Patch size
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[
        torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})
    ]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
181
182


183
class KeyeVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings that bypass the vision tower.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
        - nv: Number of videos
        - g: Grid dimensions (3 for t, h, w)
    """

    type: Literal["video_embeds"]
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
196
197


198
# A video input is either raw pixel patches or precomputed embeddings.
KeyeVideoInputs: TypeAlias = KeyeVideoPixelInputs | KeyeVideoEmbeddingInputs
199
200
201
202
203
204
205
206
207
208


class KeyeVisionEmbeddings(nn.Module):
    """Patch and position embeddings for the Keye SigLIP vision tower.

    Patches are embedded by a ``Conv2dLayer`` whose kernel and stride both
    equal ``config.patch_size``.  Position information comes from one of two
    tables:

    * ``position_embedding`` (``num_patches`` rows), bilinearly resized to
      each image's ``(h, w)`` grid when ``interpolate_pos_encoding`` is set
      and ``image_grid_thw`` is given;
    * ``packing_position_embedding`` (32768 rows), indexed by
      ``position_ids`` otherwise.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        # (h, w) -> interpolated position table, and (h, w) -> hit count;
        # both maintained by fetch_position_embedding_lfu_cache.
        self.cache_position_embedding = {}
        self.cache_position_count = {}
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        """Bilinearly resize the learned position table to a new grid.

        The table is reshaped to a ``sqrt(num_positions)`` square grid, so
        ``num_positions`` is assumed to be a perfect square.

        Args:
            embeddings: Only its last dimension (the embedding size) is read.
            height: Grid height in patches if ``is_after_patchify``, else in
                pixels (divided by ``patch_size`` here).
            width: Grid width, same convention as ``height``.
            is_after_patchify: Whether ``height``/``width`` are already
                patch counts.

        Returns:
            Tensor of shape ``(1, new_height * new_width, dim)``.
        """
        num_positions = self.position_embedding.weight.shape[0]
        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
        dim = embeddings.shape[-1]

        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        # interpolate() expects channels-first: (1, dim, H, W).
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = 20):
        """Return the interpolated position table for grid ``(h, w)``, cached.

        The cache holds at most ``max_cache`` grids.  When full, the grid with
        the lowest hit count is evicted before the new one is inserted.
        """
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]

        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)

        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        """Embed patches and add position embeddings.

        Args:
            pixel_values: 4-D (a leading batch dim of 1 is added) or 5-D
                ``(b, l, c, h, w)`` patch tensor.
            position_ids: Indices into ``packing_position_embedding``;
                required even when interpolation is used.
            image_grid_thw: Per-image ``(t, h, w)`` grids; patches are
                consumed in order, ``t * h * w`` per image.
            interpolate_pos_encoding: Use the resized learned table instead
                of the packing table (only if ``image_grid_thw`` is given).

        Raises:
            ValueError: If ``pixel_values`` is not 4-D/5-D, or
                ``position_ids`` is None.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() != 5:
            raise ValueError(
                "Unsupported pixel_values dimension:"
                f" {pixel_values.dim()}. Expected 4 or 5."
            )
        if position_ids is None:
            raise ValueError(
                "position_ids cannot be None when pixel_values.dim() is 5."
            )

        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # Presumably each input patch is patch_size x patch_size, so the conv
        # output is 1x1 spatially and this leaves one row per patch — confirm.
        embeddings = patch_embeds.flatten(-2).squeeze(-1)

        if not (interpolate_pos_encoding and image_grid_thw is not None):
            return embeddings + self.packing_position_embedding(position_ids)

        start = 0
        per_image = []
        for image_grid in image_grid_thw:
            t, h, w = image_grid
            end = start + t * h * w
            image_embeddings = embeddings[start:end, :]
            # One (h*w, dim) table, tiled across the t temporal frames.
            position_embedding = (
                self.interpolate_pos_encoding(image_embeddings, h, w, True)
                .squeeze(0)
                .repeat(t, 1)
            )
            per_image.append(image_embeddings + position_embedding)
            start = end
        return torch.concat(per_image, dim=0).unsqueeze(0)
336
337
338
339
340
341
342
343
344
345
346


def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k``.

    Only the first half of ``cos``/``sin`` along the last dimension is passed
    to the rotary kernel.  Presumably the tables are duplicated halves (the
    encoder builds them with ``repeat(1, 2)``) — confirm for other callers.
    The rotation runs in float32 and the results are cast back to the input
    dtypes.

    The kernel is chosen by platform:

    * CUDA: vLLM's flash-attn rotary op;
    * ROCm: flash-attn's Triton rotary op;
    * otherwise: the PyTorch fallback in NeoX style.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes and dtypes as ``q``/``k``.
    """
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    elif current_platform.is_rocm():
        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
    else:
        # For other platforms, use PyTorch fallback
        from vllm.model_executor.layers.rotary_embedding.common import (
            apply_rotary_emb_torch,
        )

        apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
    return q_embed, k_embed


class KeyeSiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You
    Need' paper.

    Operates on packed sequences delimited by ``cu_seqlens`` and supports
    the FLASH_ATTN, ROCM_AITER_FA (both varlen flash kernels) and
    TORCH_SDPA (per-sequence loop) backends.  Q/K/V heads are sharded over
    the tensor-parallel group.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        # Plain multi-head attention: as many KV heads as query heads.
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                use_upstream_fa=False,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
        cu_seqlens: torch.Tensor | None = None,
        rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Attend within each packed sequence and project the result.

        Args:
            hidden_states: Presumably ``(batch, seq, hidden)``; dim 0 is read
                as the batch size — confirm.
            attention_mask: Unused.
            output_attentions: Unused.
            cu_seqlens: Cumulative sequence boundaries (1-D tensor); required.
            rope_emb: Optional ``(cos, sin)`` applied to q and k.

        Raises:
            ValueError: If ``cu_seqlens`` is None.
        """
        # Both backends split the packed sequence with cu_seqlens, so check it
        # before any work (previously it was dereferenced before the check).
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None.")

        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        batch_size = q.shape[0]

        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)
        if rope_emb is not None:
            cos, sin = rope_emb
            q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)

        if self.is_flash_attn_backend:
            # .item() syncs with the device, so only compute it where needed.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=False,
                softmax_scale=self.scale,
            )
            context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Run SDPA independently on each [start, end) segment so tokens do
            # not attend across sequence boundaries.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]

        context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        return output


class SigLIPRotaryEmbedding(nn.Module):
    """Rotary-angle table ``position * inv_freq`` for integer positions.

    ``inv_freq[i] = 1 / theta ** (2 * i / dim)`` for ``i`` in
    ``range(0, dim, 2) / 2``, stored as a non-persistent buffer.

    Args:
        dim: Rotary dimension; the table has ``ceil(dim / 2)`` columns.
        theta: Frequency base.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        """Build the float32 ``inv_freq`` buffer from ``dim`` and ``theta``."""
        inv_freq = 1.0 / (
            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)
        )
        # Non-persistent: recomputed at construction, not saved in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` table of angles."""
        seq = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        return torch.outer(seq, self.inv_freq)


class KeyeSiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block.

    Computes ``x + attn(ln1(x))`` and then ``x + mlp(ln2(x))``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = KeyeSiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_backend_override=attn_backend_override,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: bool | None = False,
        cu_seqlens: torch.Tensor | None = None,
        rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each with a residual connection.

        ``attention_mask`` and ``output_attentions`` are forwarded to the
        attention module unchanged.  Returns a tensor with the same shape as
        ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class KeyeSiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` encoder layers.

    When ``use_rope`` is set, 2-D rotary embeddings are applied.  Each
    token's (row, column) positions index a shared angle table of width
    ``head_dim // 2``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList(
            [
                KeyeSiglipEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level: list items are expanded, other items kept as is."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cu_seqlens: torch.Tensor | None = None,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
        use_rope: bool | None = False,
        window_size: int | None = -1,
        vision_or_text: str = "vision",
    ) -> torch.Tensor:
        """Run all layers and return the final hidden states.

        Args:
            inputs_embeds: Packed token embeddings.
            attention_mask: Must be None (asserted).
            cu_seqlens: Sequence boundaries, forwarded to every layer.
            image_grid_thw: ``(t, h, w)`` grids; used to derive row/column
                ids when ``use_rope`` is set and either id tensor is missing.
            height_position_ids / width_position_ids: Optional precomputed
                row/column ids per token.
            use_rope: Apply 2-D rotary embeddings (only when exactly True).
            output_hidden_states, window_size, vision_or_text: Unused.
        """
        device = inputs_embeds.device
        if use_rope is True:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)

            if width_position_ids is None or height_position_ids is None:
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    # Row-major index within a frame, repeated for t frames.
                    image_pids = torch.arange(t * h * w, device=device) % (h * w)
                    split_hids.append(image_pids // w)
                    split_wids.append(image_pids % w)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)

            pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            # Concatenate row and column angles per token, then duplicate the
            # halves.
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None

        assert attention_mask is None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=cu_seqlens,
                rope_emb=rope_emb,
            )
        return hidden_states


class KeyeSiglipVisionTransformer(nn.Module):
    """Embeddings, encoder and post-LayerNorm.

    The packed output is split back into one tensor per sequence using
    ``cu_seqlens``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = KeyeVisionEmbeddings(config)
        self.encoder = KeyeSiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool | None = False,
        attention_mask: torch.Tensor | None = None,
        sample_indices: torch.Tensor | None = None,
        image_indices: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        height_position_ids: torch.Tensor | None = None,
        width_position_ids: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
        vision_return_embed_list: bool | None = False,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        return_pooler_output: bool | None = True,
        use_rope: bool | None = False,
        window_size: int | None = -1,
    ) -> list[torch.Tensor]:
        """Encode packed patches and return one hidden-state tensor per sequence.

        Raises:
            ValueError: If ``cu_seqlens`` is None.  This is checked before any
                computation; previously it ran only after the full encoder.
        """
        if cu_seqlens is None:
            raise ValueError(
                "cu_seqlens cannot be None for "
                "SiglipVisionTransformer output processing."
            )

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Slice [start, end) along dim 1 for each consecutive cu_seqlens pair;
        # squeeze(0) drops a leading batch dim of size 1.
        return [
            last_hidden_state[:, start:end, :].squeeze(0)
            for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
        ]


class KeyeSiglipVisionModel(nn.Module):
    """Wrapper owning a ``KeyeSiglipVisionTransformer`` as ``vision_model``.

    Exposes dtype/device helpers, forwards all inputs to the transformer, and
    loads checkpoint weights for the vision tower (fusing separate q/k/v
    projection tensors into the ``qkv_proj`` parameter).
    """

    config_class = PretrainedConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()

        self.vision_model = KeyeSiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            attn_backend_override=attn_backend_override,
        )
        # Kept so load_weights can resolve KV-cache scale names.
        self.quant_config = quant_config

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight (used to cast pixel inputs)."""
        return self.vision_model.embeddings.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight."""
        return self.vision_model.embeddings.patch_embedding.weight.device

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
        position_ids: torch.Tensor | None = None,
        vision_return_embed_list: bool | None = False,
        image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]]
        | None = None,
        cu_seqlens: torch.Tensor | None = None,
        return_pooler_output: bool | None = True,
        use_rope: bool | None = False,
        window_size: int | None = -1,
    ) -> list[torch.Tensor]:
        """Run the vision transformer.

        All arguments are forwarded unchanged to ``self.vision_model``. The
        transformer slices its last hidden state by ``cu_seqlens`` and returns
        one tensor per sample, so the result is a list of tensors.
        """
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        - Names containing any of the skipped substrings (rotary inverse
          frequencies, pooling-head weights) are ignored.
        - If a quant config maps the name to a cache-scale parameter, the
          tensor (its first element unless it is already a scalar) is loaded
          there.
        - ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as shards of the
          fused ``qkv_proj`` parameter.
        - Everything else is loaded by name, after kv-scale remapping.

        Returns:
            The set of parameter names that received a tensor.
        """
        # (fused parameter name, checkpoint name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        skipped_substrings = (
            "rotary_emb.inv_freq",
            "head.attention",
            "head.layernorm",
            "head.mlp",
            "head.probe",
        )
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if any(substring in name for substring in skipped_substrings):
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # NOTE: `name` is rewritten in place; if one of the checks
                # below skips, the for-else branch re-checks the rewritten
                # name and skips it the same way.
                name = name.replace(weight_name, param_name)
                # Skip biases that have no matching parameter in this module.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Projector(nn.Module):
    """Project vision-tower features into the language model's hidden size.

    Features are LayerNorm-ed at the vision hidden size. Each group of
    ``merge_kernel_size[0] * merge_kernel_size[1]`` patches is concatenated
    into one vector, which then goes through Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        # Fixed 2x2 spatial merge (not read from the config).
        self.merge_kernel_size = (2, 2)

        # Width of one merged token: vision hidden size times patches merged.
        self.hidden_size = (
            self.vision_config.hidden_size
            * self.merge_kernel_size[0]
            * self.merge_kernel_size[1]
        )

        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05)
        self.act = GELUActivation()

        self.linear_1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.linear_2 = RowParallelLinear(
            self.hidden_size,
            self.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def _mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply linear_1 -> GELU -> linear_2 to merged features."""
        # The parallel linear layers return a 2-tuple whose first item is the
        # output, so unpack it before feeding the activation.
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states

    def forward(
        self,
        image_features: torch.Tensor | list[torch.Tensor],
        image_grid_thw: list[tuple[int, int, int]] | torch.Tensor,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Merge and project vision features.

        Args:
            image_features: Either a list/tuple holding one
                ``(t * h * w, vision_hidden)`` tensor per item, or a single
                tensor whose last dim is the vision hidden size.
            image_grid_thw: Per-item ``(t, h, w)`` patch grids, matching
                ``image_features`` item by item (list path only). Rows may be
                tuples of ints or rows of a 2-D tensor.

        Returns:
            For list input, a list of ``(t * h/m1 * w/m2, text_hidden)``
            tensors; the patches in each ``m1 x m2`` window are gathered by
            ``rearrange``. For tensor input, the projected tensor reshaped to
            ``(*input.shape[:-1], -1)``.
        """
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = []
            for image_feature, image_grid in zip(image_features, image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                # Grid rows may be 0-d tensors; einops axis sizes must be ints.
                t, h, w = (int(v) for v in image_grid)

                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                processed_features.append(self._mlp(image_feature))

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        # Flatten all leading dims (also works when there are none).
        image_features = image_features.view(-1, dim)
        # This path concatenates every m1*m2 consecutive rows, with no spatial
        # rearrangement.
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        hidden_states = self._mlp(hidden_states)

        # NOTE(review): the merge shrinks the row count by m1*m2, so this view
        # folds that factor into the last dim; confirm callers expect that.
        return hidden_states.view(*dims, -1)


def _keye_field_config(
    hf_inputs: Mapping[str, torch.Tensor],
):
    """Describe how each processor output is split into per-item fields.

    Pixel values and precomputed embeddings are stored flat across items, so
    they are split by each item's patch count ``t * h * w``. The grid tensors
    hold one row per item and are batched. A missing grid key is treated as an
    empty ``(0, 3)`` grid.
    """
    image_sizes = hf_inputs.get("image_grid_thw", torch.empty((0, 3))).prod(-1)
    video_sizes = hf_inputs.get("video_grid_thw", torch.empty((0, 3))).prod(-1)

    return {
        "pixel_values": MultiModalFieldConfig.flat_from_sizes("image", image_sizes),
        "image_embeds": MultiModalFieldConfig.flat_from_sizes("image", image_sizes),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
        "pixel_values_videos": MultiModalFieldConfig.flat_from_sizes(
            "video", video_sizes
        ),
        "video_embeds": MultiModalFieldConfig.flat_from_sizes("video", video_sizes),
        "video_grid_thw": MultiModalFieldConfig.batched("video"),
    }


class KeyeMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts precomputed embeddings given as dicts.

    A dict input is wrapped in ``DictEmbeddingItems``, which requires the
    embeddings and grid fields. Any other input is handled by the base parser.
    """

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any]:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={"image_embeds", "image_grid_thw"},
            fields_factory=_keye_field_config,
        )

    def _parse_video_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any]:
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={"video_embeds", "video_grid_thw"},
            fields_factory=_keye_field_config,
        )


class KeyeProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Keye-VL: vision token counts and max sizes."""

    def get_max_image_size(self) -> int:
        """Side length requested for the largest image before resizing."""
        max_side = 9999999  # _MAX_IMAGE_SIZE
        return max_side

    def get_max_frame_per_video(self) -> int:
        """Cap on frames per video used when sizing dummy videos."""
        max_frames = 16  # _MAX_FRAMES_PER_VIDEO
        return max_frames

    def get_image_processor(self, **kwargs: object):
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(
        self,
    ) -> Mapping[str, int | None]:
        # Both modalities report None (no fixed limit is declared here).
        return dict(image=None, video=None)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor,
    ) -> tuple[ImageSize, int]:
        """Return the (optionally resized) frame size and its vision tokens.

        The token count is ``t * (H / patch) * (W / patch) / merge_size**2``,
        using integer division, with ``t`` at least 1.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = 1

        if not do_resize:
            preprocessed_size = ImageSize(width=image_width, height=image_height)
        else:
            new_height, new_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=new_width, height=new_height)

        # Pad the frame count up to a multiple of the temporal patch size.
        padded_frames = num_frames + num_frames % temporal_patch_size
        grid_t = max(padded_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        token_count = (grid_t * grid_h * grid_w) // (merge_size**2)
        return preprocessed_size, token_count

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        vision_info = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return vision_info[1]

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor,
    ) -> int:
        vision_info = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return vision_info[1]

    def get_image_size_with_most_features(
        self,
    ) -> ImageSize:
        """Resized size of a maximal square image (``get_max_image_size``)."""
        side = self.get_max_image_size()
        resized, _ = self._get_vision_info(
            image_width=side,
            image_height=side,
            image_processor=None,
        )
        return resized

    def get_max_image_tokens(self) -> int:
        biggest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=biggest.width,
            image_height=biggest.height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose video token count fits in ``max_tokens``."""
        biggest = self.get_image_size_with_most_features()

        def tokens_for(frame_count: int) -> int:
            return self.get_num_video_tokens(
                image_width=biggest.width,
                image_height=biggest.height,
                num_frames=frame_count,
                image_processor=None,
            )

        frame_count = 0
        while tokens_for(frame_count + 1) <= max_tokens:
            frame_count += 1
        return frame_count

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        """Frames per video after reserving room for the maximum images.

        The result is capped by ``get_max_frame_per_video`` and is at least 1.
        """
        mm_config = self.ctx.get_mm_config()
        image_limit = mm_config.get_limit_per_prompt("image")
        video_limit = mm_config.get_limit_per_prompt("video")

        token_budget = seq_len - self.get_max_image_tokens() * image_limit
        per_video = self._get_max_video_frames(token_budget) // max(video_limit, 1)
        per_video = min(per_video, self.get_max_frame_per_video())

        return max(per_video, 1)

    def get_max_video_tokens(self, seq_len: int) -> int:
        biggest = self.get_image_size_with_most_features()
        frame_count = self.get_num_frames_with_most_features(seq_len)
        return self.get_num_video_tokens(
            image_width=biggest.width,
            image_height=biggest.height,
            num_frames=frame_count,
            image_processor=None,
        )


_I = TypeVar("_I", bound=KeyeProcessingInfo)


class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Build dummy prompt text and media for Keye-style processors."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one video token per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token
        video_token: str = hf_processor.video_token

        return "".join([image_token * image_count, video_token * video_count])

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images and videos.

        Sizes come from the processing info's largest-feature image size and
        frame count. Per-modality overrides are taken from ``mm_options``.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len)

        options = mm_options or {}
        image_overrides = options.get("image")
        video_overrides = options.get("video")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}


class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
    """Dummy-input builder bound to ``KeyeProcessingInfo``."""


class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
    """Processor that expands each image/video placeholder into patch tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        return KeyeMultiModalDataParser()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each placeholder token with ``prod(grid_thw) // merge**2`` copies."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        token_ids = {
            modality: vocab[token]
            for modality, token in (
                ("image", hf_processor.image_token),
                ("video", hf_processor.video_token),
            )
        }
        # Patches merged into one token.
        patches_per_token = image_processor.merge_size**2

        def expand(item_idx: int, modality: str):
            grid = out_mm_kwargs[modality][item_idx][f"{modality}_grid_thw"].data
            assert isinstance(grid, torch.Tensor)

            repeat = int(grid.prod()) // patches_per_token
            return [token_ids[modality]] * repeat

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[token_ids[modality]],
                    replacement=partial(expand, modality=modality),
                )
            )
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _keye_field_config(hf_inputs)


class BaseKeyeModule(nn.Module):
    """Shared implementation for Keye-VL models.

    Owns the vision tower (``visual``), a projector (``mlp_AR``) built by
    ``_build_projector``, and a ``Qwen3ForCausalLM`` language model.
    ``_parse_and_validate_image_input``, ``_parse_and_validate_video_input``
    and ``_process_video_input`` are called here but not defined on this
    class; subclasses are expected to provide them.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Route HF checkpoint prefixes onto the nested language model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video item."""
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: PretrainedConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.visual = KeyeSiglipVisionModel(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            attn_backend_override=attn_backend_override,
        )

        self.mlp_AR = self._build_projector(
            config,
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "mlp_AR"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    # NOTE: nn.Module is not an ABC, so @abstractmethod is not enforced at
    # instantiation; this body is what runs if a subclass forgets to override.
    @abstractmethod
    def _build_projector(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> nn.Module:
        raise NotImplementedError(
            f"{type(self).__name__} must implement _build_projector"
        )

    @staticmethod
    def _build_siglip_metadata(
        grid_thw: torch.Tensor,
        device: torch.device,
    ) -> tuple[list[tuple], torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build per-patch metadata for a 2-D tensor of ``(t, h, w)`` rows.

        Returns:
            grid_hws: each row as a tuple of Python ints.
            position_ids: for each item ``arange(t*h*w) % (h*w)``, concatenated.
            cu_seqlens: int32 cumulative patch counts starting at 0.
            sample_indices: int64 item index of every patch.
        All tensors are moved to ``device``.
        """
        assert grid_thw.ndim == 2

        grid_hws = list()
        position_ids = list()
        sample_indices = list()
        cu_seqlens = [0]
        for idx, thw in enumerate(grid_thw):
            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)
            grid_hws.append(thw_tuple)
            # Positions restart every frame (period h * w).
            position_ids.append(torch.arange(numel) % np.prod(thw_tuple[1:]))
            sample_indices.append(torch.full((numel,), idx, dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        return (
            grid_hws,
            torch.concat(position_ids, dim=0).to(device),
            torch.tensor(cu_seqlens, dtype=torch.int32).to(device),
            torch.concat(sample_indices, dim=0).to(device),
        )

    def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]:
        """Encode pixel-value images and project them; one tensor per image.

        Raises:
            ValueError: if the input carries precomputed image embeddings.
        """
        if image_input["type"] == "image_embeds":
            raise ValueError(
                "Image embeddings are not supported for this processing path."
            )

        image_grid_thw = image_input["image_grid_thw"]
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        (
            image_grid_hws,
            siglip_position_ids,
            cu_seqlens,
            sample_indices,
        ) = self._build_siglip_metadata(image_grid_thw, pixel_values.device)

        image_embeds = self.visual(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=False,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        return tuple(self.mlp_AR(image_embeds, image_grid_thw))

    def _process_video_embeds(
        self,
        video_type: Literal["video_embeds", "pixel_values_videos"],
        video_grid_thw: torch.Tensor,
        pixel_values_videos: torch.Tensor | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Encode pixel-value videos and project them (result of ``mlp_AR``).

        Raises:
            ValueError: if ``video_type`` is ``"video_embeds"``.
        """
        if video_type == "video_embeds":
            raise ValueError(
                "Video embeddings are not supported for this processing path."
            )

        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        (
            video_grid_hws,
            siglip_position_ids,
            cu_seqlens,
            sample_indices,
        ) = self._build_siglip_metadata(video_grid_thw, pixel_values_videos.device)

        video_embeds = self.visual(
            pixel_values=pixel_values_videos,
            image_grid_thw=video_grid_hws,
            position_ids=siglip_position_ids,
            vision_return_embed_list=True,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        return self.mlp_AR(video_embeds, video_grid_thw)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed in the order they appear in kwargs.

        Modalities whose parser returns ``None`` are left out, so callers never
        receive a ``None`` input to process.
        """
        modalities = {}

        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                image_input = self._parse_and_validate_image_input(**kwargs)
                if image_input is not None:
                    modalities["images"] = image_input
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                video_input = self._parse_and_validate_video_input(**kwargs)
                if video_input is not None:
                    modalities["videos"] = video_input

        return modalities

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Return per-item embeddings for all images and videos, or ``None``."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # Preserve the order in which modalities appeared in the inputs.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Keye-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings; ignored when
                ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp_AR.",
            tower_model="visual.",
        )


@MULTIMODAL_REGISTRY.register_processor(
    KeyeMultiModalProcessor,
    info=KeyeProcessingInfo,
    dummy_inputs=KeyeDummyInputsBuilder,
)
1567
class KeyeForConditionalGeneration(
1568
    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
1569
1570
1571
1572
1573
):
    def _build_projector(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> nn.Module:
        """Build the ``Projector`` module for Keye-VL.

        Args:
            text_config: Language-model config, forwarded to ``Projector``.
            vision_config: Vision-tower config, forwarded to ``Projector``.
            quant_config: Optional quantization config, forwarded as-is.
            prefix: Parameter-name prefix, forwarded as-is.

        Returns:
            A newly constructed ``Projector``.
        """
        return Projector(text_config, vision_config, quant_config, prefix)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KeyeImageInputs | None:
        """Extract image inputs from multimodal keyword arguments.

        Pops ``pixel_values``, ``image_embeds`` and ``image_grid_thw`` from
        ``kwargs``. When both ``pixel_values`` and ``image_embeds`` are given,
        ``pixel_values`` wins.

        Returns:
            ``KeyeImagePixelInputs`` if ``pixel_values`` is present, otherwise
            ``KeyeImageEmbeddingInputs`` if ``image_embeds`` is present,
            otherwise ``None``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return KeyeImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return KeyeImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        # Neither raw pixels nor precomputed embeddings were supplied.
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> KeyeVideoInputs | None:
        """Extract video inputs from multimodal keyword arguments.

        Pops ``pixel_values_videos``, ``video_embeds`` and ``video_grid_thw``
        from ``kwargs``. When both ``pixel_values_videos`` and
        ``video_embeds`` are given, ``pixel_values_videos`` wins.

        Returns:
            ``KeyeVideoPixelInputs`` if ``pixel_values_videos`` is present,
            otherwise ``KeyeVideoEmbeddingInputs`` if ``video_embeds`` is
            present, otherwise ``None``.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is not None:
            return KeyeVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return KeyeVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

        # Neither raw frames nor precomputed embeddings were supplied.
        return None

    def _process_video_input(
        self, video_input: KeyeVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Run ``_process_video_embeds`` on a parsed video input.

        Args:
            video_input: Output of ``_parse_and_validate_video_input``.

        Returns:
            The items yielded by ``_process_video_embeds``, collected into a
            tuple.
        """
        video_type = video_input["type"]
        video_grid_thw = video_input["video_grid_thw"]
        # The embeddings variant is built without ``pixel_values_videos``,
        # so this is None in that case.
        pixel_values_videos = video_input.get("pixel_values_videos")

        return tuple(
            self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
        )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-component M-RoPE position ids for a Keye-VL prompt.

        Plain text tokens advance all three (t, h, w) components together.
        Image and video placeholder runs are processed in the order they
        appear in ``input_tokens``, where each run is located with
        ``list.index`` from the current cursor. Each run receives a (t, h, w)
        grid of positions, offset past the text that precedes it. The h and
        w extents are divided by ``spatial_merge_size``.

        Videos are handled frame by frame. Every ``[t, h, w]`` video grid row
        is expanded into ``t`` rows of ``[1, h, w]``, and each such row
        consumes one run of video placeholder tokens.

        Args:
            input_tokens: Prompt token ids.
            mm_features: Multimodal feature specs of the request. Their
                ``image_grid_thw`` / ``video_grid_thw`` entries provide the
                patch grids.

        Returns:
            ``(positions, delta)``: ``positions`` is reshaped to ``(3, -1)``,
            with one column per token when the placeholder runs match the
            grids. ``delta`` is ``positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]

        # A video item's grid may be a single [t, h, w] row or a stack of
        # rows. Normalize each item to 2-D and concatenate all items, so that
        # every video of the request is counted, not only the first one.
        # NOTE(review): assumes mm_features order matches the order of the
        # video placeholders in input_tokens -- confirm.
        video_grid_rows: list[list[int]] = []
        for item in kwargs.get("video_grid_thw", []):
            video_grid_rows.extend(torch.as_tensor(item).reshape(-1, 3).tolist())

        # Split each [t, h, w] row along t into t per-frame rows [1, h, w].
        video_grid_thw = [[1, h, w] for t, h, w in video_grid_rows for _ in range(t)]

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        image_nums = len(image_grid_thw)
        frame_nums = len(video_grid_thw)
        llm_pos_ids_list: list[torch.Tensor] = []

        st = 0
        remain_images, remain_frames = image_nums, frame_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + frame_nums):
            # len(input_tokens) + 1 acts as "no further placeholder of this
            # kind", so the other modality wins the comparison below.
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_frames > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw[video_index]
                video_index += 1
                remain_frames -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            # Text before this placeholder run: identical ids in all 3 rows.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            # Placeholder run: (t, h, w) grid coordinates, flattened t-major.
            t_index = (
                torch.arange(llm_grid_t)
                .view(-1, 1)
                .expand(-1, llm_grid_h * llm_grid_w)
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            # Skip past the placeholder tokens covered by this grid.
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last placeholder run.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return llm_positions, mrope_position_delta