ernie45_vl.py 59.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
24
"""Inference-only Ernie VL model compatible with HuggingFace weights."""
25

26
import itertools
27
import math
28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal
31
32
33
34
35

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
36
from einops import rearrange
37
from transformers import BatchFeature
38

39
from vllm.attention.backends.registry import AttentionBackendEnum
40
41
from vllm.attention.layers.mm_encoder_attention import (
    MMEncoderAttention,
42
)
43
from vllm.config import MultiModalConfig, VllmConfig
44
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
45
46
47
48
49
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.layernorm import RMSNorm
50
51
52
53
54
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
55
from vllm.model_executor.layers.quantization import QuantizationConfig
56
57
58
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
59
60
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
61
62
from vllm.multimodal.inputs import (
    MultiModalDataDict,
63
    MultiModalFeatureSpec,
64
65
66
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
67
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
68
69
70
71
72
73
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
74
75
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
76
from vllm.utils.tensor_schema import TensorSchema, TensorShape
77
78

from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
79
80
81
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
82
    SupportsMRoPE,
83
84
85
    SupportsMultiModal,
    SupportsPP,
)
86
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
87
88
89
90
91
92
93
94
95
96
from .vision import get_vit_attn_backend

# Module-level logger for this file (named after the module via __name__).
logger = init_logger(__name__)

# === Vision Transformer === #


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` across the TP group and interleave the shards.

    Every rank's tensor is split along its last dimension into chunks of
    ``hidden_size // tp_size``. The result places chunk ``i`` of rank 0,
    rank 1, ..., rank ``tp_size - 1`` next to each other, for every chunk
    index ``i`` in order, and concatenates along the last dimension.

    Args:
        local_tensor: This rank's tensor.
        hidden_size: Unsharded size used to derive the chunk width
            ``hidden_size // tp_size``.
        tp_size: Number of ranks in the tensor-parallel group.

    Returns:
        The interleaved, gathered tensor.
    """
    import torch.distributed as dist

    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
    )

    chunk_size = hidden_size // tp_size
    gathered_tensors_split = [
        torch.split(tensor, chunk_size, -1) for tensor in gathered_tensors
    ]
    # zip(*...) groups chunk i from every rank; flattening those groups gives
    # [rank0_chunk0, rank1_chunk0, ..., rank0_chunk1, rank1_chunk1, ...].
    ordered_tensors = [
        tensor for group in zip(*gathered_tensors_split) for tensor in group
    ]
    return torch.cat(ordered_tensors, dim=-1)


class Ernie4_5_VisionAttention(nn.Module):
    """VisionAttention using VLLM framework APIs.

    QKV and output projections are tensor-parallel linears; attention runs
    through ``MMEncoderAttention`` over variable-length segments described by
    ``cu_seqlens``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )

        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused QKV projection into per-partition q, k, v.

        With ``tp_size > 1`` the local shards are first all-gathered in
        interleaved order, then each of q/k/v is re-split along the last dim
        and only this rank's slice is kept. Each result is viewed as
        ``[s, b, heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Apply rotary-embedded self-attention and the output projection.

        ``x`` is sequence-first (``[s, b, c]``); the output has the same layout.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        # Switch to batch-first layout for rotary embedding and attention.
        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous() for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k with a single call by stacking them on the batch dim.
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = self.apply_rotary_emb(
                qk_concat,
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        output = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Ernie4_5_VisionMLP(nn.Module):
    """Two-layer feed-forward network of the vision encoder.

    ``fc1`` (column-parallel) -> activation -> ``fc2`` (row-parallel); the
    output feature size equals ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fc1 output is sharded across TP ranks; fc2 consumes the shard.
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


class Ernie4_5_VisionBlock(nn.Module):
    """Pre-norm transformer block of the vision encoder.

    Computes ``h = x + attn(norm1(x))`` followed by ``h + mlp(norm2(h))``.
    ``norm_layer`` defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Ernie4_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )

        self.mlp = Ernie4_5_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class Ernie4_5_VisionPatchEmbed(nn.Module):
    """Linear patch embedding for pre-flattened image patches.

    Input rows are expected to already contain
    ``in_channels * patch_size * patch_size`` values each; they are projected
    to ``embed_dim`` by a bias-free linear layer.
    """

    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 1280,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        self.proj = nn.Linear(
            in_channels * patch_size * patch_size, embed_dim, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cast to the weight dtype so mixed-precision inputs don't fail matmul.
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.to(target_dtype)
        return self.proj(hidden_states)


class Ernie4_5_VisionRotaryEmbedding(nn.Module):
    """Rotary frequency table for integer positions ``0 .. seqlen - 1``.

    ``inv_freq[i] = 1 / theta ** (2i / dim)`` for ``i`` in ``[0, dim / 2)``.
    It is stored as a plain attribute (not a registered buffer), so it stays
    on the device where it was created.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.inv_freq = 1.0 / theta ** (
            torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim
        )

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return ``outer(arange(seqlen), inv_freq)`` of shape ``[seqlen, dim // 2]``."""
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        return torch.outer(input=seq, vec2=self.inv_freq)


class Ernie4_5_VisionTransformer(nn.Module):
    """Ernie 4.5 vision encoder operating on pre-flattened patches.

    Patches from every image/video in the batch are packed into a single
    sequence. ``cu_seqlens`` (built from ``grid_thw``) splits that sequence
    into one attention segment of ``h * w`` tokens per temporal slice.
    """

    def __init__(
        self,
        vision_config,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        patch_size = vision_config.patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Ernie4_5_VisionPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            prefix=f"{prefix}.patch_embed",
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Half of head_dim per axis: rot_pos_emb concatenates h and w freqs.
        self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Ernie4_5_VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(depth)
            ]
        )

        assert hidden_size == embed_dim, (
            "vit's config.hidden must be equal to config.embed_dim"
        )
        # NOTE(review): final norm uses a fixed eps=1e-6 rather than norm_eps;
        # confirm this matches the reference checkpoint before changing.
        self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend if multimodal_config else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build rotary frequencies for every patch described by ``grid_thw``.

        For each ``(t, h, w)`` row, (row, column) position ids are reordered so
        each ``spatial_merge_size x spatial_merge_size`` window is contiguous,
        repeated ``t`` times, and used to gather from a frequency table sized by
        the largest grid side. Row and column frequencies are concatenated.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
        """Return the longest segment length for FlashAttention-style backends.

        Returns ``None`` for any other attention backend.
        """
        max_seqlen = None
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
    ) -> torch.Tensor:
        """Encode packed patches.

        Args:
            hidden_states: Flattened patches of all items.
            grid_thw: One ``(t, h, w)`` row per image/video.
            num_pad: Number of trailing padding tokens; when positive they form
                one extra attention segment at the end.
        """
        hidden_states = self.patch_embed(hidden_states)

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)

        # One segment of h * w tokens per temporal slice of every item.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)

        zeros = cu_seqlens.new_zeros(1)
        if num_pad > 0:
            cu_seqlens = torch.cat([zeros, cu_seqlens, zeros])
            cu_seqlens[-1] = cu_seqlens[-2] + num_pad
        else:
            cu_seqlens = torch.cat([zeros, cu_seqlens])

        # add batch size
        if hidden_states.ndim == 2:
            hidden_states = hidden_states.unsqueeze(dim=1)

        # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
            )

        final_output = self.ln(hidden_states)

        if final_output.ndim == 3:
            final_output = final_output.squeeze(dim=1)

        return final_output

    def load_weights(self, weights) -> set[str]:
        """Load ``(name, tensor)`` pairs into matching parameters.

        Raises ``KeyError`` for names that are not parameters of this module.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


# === Vision Inputs === #


521
class Ernie4_5_VLImagePixelInputs(TensorSchema):
    """
    Flattened image patches and their per-image grids.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs


539
class Ernie4_5_VLVideoPixelInputs(TensorSchema):
    """
    Flattened video patches and their per-video grids.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ni: Number of videos
        - cps: Number of channels * temporal_patch_size * patch_size *
              patch_size
    """

    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs

# === Vision Processor === #


559
def round_by_factor(number: int | float, factor: int) -> int:
560
561
562
    return round(number / factor) * factor


563
def ceil_by_factor(number: int | float, factor: int) -> int:
564
565
566
    return math.ceil(number / factor) * factor


567
def floor_by_factor(number: int | float, factor: int) -> int:
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 16384 * 28 * 28,
):
    """Choose a resized ``(height, width)`` whose sides are multiples of ``factor``.

    Steps:
      1. If the aspect ratio exceeds ``MAX_RATIO`` (200), the shorter side is
         rounded to a multiple of ``factor`` (at least ``factor``) and the
         longer side is clamped to ``MAX_RATIO`` times it, floored to a
         multiple of ``factor``.
      2. Each side is rounded to the nearest multiple of ``factor`` (at least
         ``factor``).
      3. If the area exceeds ``max_pixels`` both sides are scaled down by a
         common factor and floored; if it is below ``min_pixels`` they are
         scaled up and ceiled.

    Returns:
        ``(h_bar, w_bar)``.

    Raises:
        ValueError: if the final area is outside ``[min_pixels, max_pixels]``.
    """
    MAX_RATIO = 200
    if max(height, width) / min(height, width) > MAX_RATIO:
        if height > width:
            new_width = max(factor, round_by_factor(width, factor))
            new_height = floor_by_factor(new_width * MAX_RATIO, factor)
        else:
            new_height = max(factor, round_by_factor(height, factor))
            new_width = floor_by_factor(new_height * MAX_RATIO, factor)

        height = new_height
        width = new_width

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
        raise ValueError(
            f"Invalid h_bar={h_bar}, w_bar={w_bar}: "
            f"h_bar * w_bar must be >= min_pixels ({min_pixels}) "
            f"and <= max_pixels ({max_pixels})."
        )

    return h_bar, w_bar


class VariableResolutionResamplerModel(nn.Module):
    """Project vision-encoder features to the language model's hidden size.

    Pipeline:
      1. Spatial merge: every ``spatial_conv_size**2`` consecutive rows are
         concatenated (presumably one merge window as ordered by the vision
         encoder — confirm against the encoder's token order), followed by two
         linears with GELU and a LayerNorm.
      2. If ``config.use_temporal_conv``: features of frame ``2k`` and frame
         ``2k + 1`` of each item are concatenated (single-frame items pair the
         frame with itself), followed by two linears with GELU and a LayerNorm.
      3. A final linear to ``out_dim`` followed by RMSNorm.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        spatial_conv_size,
        temporal_conv_size,
        config,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.config = config
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size
        self.use_temporal_conv = config.use_temporal_conv

        # compress 2d conv(picture) to 1d
        self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size
        # compress 3d conv(video) to 1d
        self.temporal_dim = (
            self.in_dim
            * self.spatial_conv_size
            * self.spatial_conv_size
            * self.temporal_conv_size
        )

        self.spatial_linear1 = ColumnParallelLinear(
            self.spatial_dim,
            self.spatial_dim,
            bias=True,
            gather_output=True,
            quant_config=getattr(config, "quant_config", None),
            prefix=f"{prefix}.spatial_linear1",
        )

        self.spatial_gelu = nn.GELU()

        self.spatial_linear2 = ColumnParallelLinear(
            self.spatial_dim,
            self.spatial_dim,
            bias=True,
            gather_output=True,
            quant_config=getattr(config, "quant_config", None),
            prefix=f"{prefix}.spatial_linear2",
        )

        self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6)

        if self.use_temporal_conv:
            self.temporal_linear1 = ColumnParallelLinear(
                self.temporal_dim,
                self.spatial_dim,
                bias=True,
                gather_output=True,
                quant_config=getattr(config, "quant_config", None),
                prefix=f"{prefix}.temporal_linear1",
            )

            self.temporal_gelu = nn.GELU()

            self.temporal_linear2 = ColumnParallelLinear(
                self.spatial_dim,
                self.spatial_dim,
                bias=True,
                gather_output=True,
                quant_config=getattr(config, "quant_config", None),
                prefix=f"{prefix}.temporal_linear2",
            )

            self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6)

        self.mlp = ColumnParallelLinear(
            self.spatial_dim,
            self.out_dim,
            bias=True,
            gather_output=True,
            quant_config=getattr(config, "quant_config", None),
            prefix=f"{prefix}.mlp",
        )

        self.after_norm = RMSNorm(
            hidden_size=out_dim, eps=getattr(config, "rms_norm_eps", 1e-6)
        )

    def spatial_conv_reshape(self, x, spatial_conv_size):
        """Reshape 2-D ``x`` of shape ``[S, C]`` to ``[-1, C * spatial_conv_size**2]``."""
        # Unpacking (rather than indexing) keeps the implicit 2-D check.
        _, C = x.shape
        return x.reshape([-1, C * (spatial_conv_size**2)])

    @staticmethod
    def _temporal_pair_indices(
        grid_thw_cpu: np.ndarray, spatial_conv_size: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Row indices of the first and second frame of every frame pair.

        Rows of the spatially merged features are laid out item by item, with
        each item holding ``t`` frames of ``h * w / spatial_conv_size**2`` rows.
        For each item, the first array collects frames ``0, 2, 4, ...`` and
        the second collects frames ``1, 3, 5, ...``. A single-frame item
        contributes frame 0 to both arrays.

        Args:
            grid_thw_cpu: Integer array with one ``(t, h, w)`` row per item.
            spatial_conv_size: Spatial merge size used before this step.

        Returns:
            ``(first_frame_rows, second_frame_rows)`` as 1-D integer arrays.
        """
        merge = spatial_conv_size**2
        grid_t = grid_thw_cpu[:, 0]
        rows_per_frame = grid_thw_cpu[:, 1:].prod(-1) // merge
        rows_per_item = grid_thw_cpu.prod(-1) // merge
        # Start row of every item = exclusive prefix sum of rows_per_item.
        item_offsets = np.empty(rows_per_item.size, dtype=rows_per_item.dtype)
        item_offsets[0] = 0
        item_offsets[1:] = rows_per_item.cumsum()[:-1]

        first_rows, second_rows = [], []
        for num_frames, frame_rows, item_offset in zip(
            grid_t, rows_per_frame, item_offsets
        ):
            for frame in range(0, num_frames, 2):
                start = item_offset + frame * frame_rows
                first_rows.append(np.arange(start, start + frame_rows))
            for frame in range(1 if num_frames > 1 else 0, num_frames, 2):
                start = item_offset + frame * frame_rows
                second_rows.append(np.arange(start, start + frame_rows))
        return (
            np.concatenate(first_rows, axis=-1),
            np.concatenate(second_rows, axis=-1),
        )

    def forward(self, x, grid_thw):
        """Resample vision features ``x`` described by ``grid_thw`` (one ``(t, h, w)`` row per item)."""

        def fwd_spatial(x):
            x = self.spatial_conv_reshape(x, self.spatial_conv_size)

            x, _ = self.spatial_linear1(x)
            x = self.spatial_gelu(x)
            x, _ = self.spatial_linear2(x)
            x = self.spatial_norm(x)

            return x

        def fwd_placeholder(x, grid_thw):
            # Concatenate each frame with its pair partner along the feature dim.
            first_rows, second_rows = self._temporal_pair_indices(
                grid_thw.cpu().numpy(), self.spatial_conv_size
            )
            first_rows = torch.tensor(first_rows).to(x.device)
            second_rows = torch.tensor(second_rows).to(x.device)

            x_timestep_1 = torch.index_select(x, dim=0, index=first_rows)
            x_timestep_2 = torch.index_select(x, dim=0, index=second_rows)
            return torch.concat([x_timestep_1, x_timestep_2], dim=-1)

        def fwd_temporal(x):
            x, _ = self.temporal_linear1(x)
            x = self.temporal_gelu(x)
            x, _ = self.temporal_linear2(x)
            x = self.temporal_norm(x)
            return x

        def fwd_mlp(x):
            x, _ = self.mlp(x)
            x = self.after_norm(x)
            return x

        x = fwd_spatial(x)
        if self.use_temporal_conv:
            x = fwd_placeholder(x, grid_thw)
            x = fwd_temporal(x)
        x = fwd_mlp(x)
        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load matching ``(name, tensor)`` pairs; unknown names are skipped."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Ernie 4.5 VL.

    Provides access to the HF config and processors, and computes vision
    token counts plus the image size / frame count that maximize them.
    """

    def get_hf_config(self):
        """Return the HuggingFace config of the served model."""
        model_config = self.ctx.model_config
        return model_config.hf_config

    def get_hf_processor(self, **kwargs: object):
        """Return the fast HuggingFace processor, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(use_fast=True, **kwargs)

    def get_image_processor(self, **kwargs: object):
        """Return the image processor attached to the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Both modalities report ``None`` as their per-prompt limit.
        return dict(image=None, video=None)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of tokens a single item can expand to."""
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len, mm_counts),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Any | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an item.

        Args:
            image_width: Original width in pixels.
            image_height: Original height in pixels.
            num_frames: Number of frames (1 for still images).
            do_resize: Whether to apply ``smart_resize`` first.
            image_processor: Processor supplying ``min_pixels`` /
                ``max_pixels``; fetched from the HF processor when ``None``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        processor = (
            image_processor
            if image_processor is not None
            else self.get_image_processor()
        )
        hf_config = self.get_hf_config()

        patch_size = hf_config.vision_config.patch_size
        spatial_conv_size = hf_config.spatial_conv_size
        temporal_conv_size = hf_config.temporal_conv_size

        if not do_resize:
            preprocessed_size = ImageSize(width=image_width, height=image_height)
        else:
            new_height, new_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * spatial_conv_size,
                min_pixels=processor.min_pixels,
                max_pixels=processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=new_width, height=new_height)

        # Frames are grouped by temporal_conv_size, but at least one group.
        grid_t = max(num_frames // temporal_conv_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        # Each spatial_conv_size x spatial_conv_size patch block is one token.
        num_vision_tokens = (grid_t * grid_h * grid_w) // (spatial_conv_size**2)
        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Any | None,
    ) -> int:
        """Return the number of vision tokens for one image."""
        return self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )[1]

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Any | None,
    ) -> int:
        """Return the number of vision tokens for one video."""
        return self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )[1]

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an arbitrarily large input image."""
        return self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            image_processor=None,
        )[0]

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest possible image."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest even frame count whose tokens fit ``max_tokens``."""
        width, height = self.get_image_size_with_most_features()

        frames = 0
        while (
            self.get_num_video_tokens(
                image_width=width,
                image_height=height,
                num_frames=frames + 1,
                image_processor=None,
            )
            <= max_tokens
        ):
            frames += 1

        # Drop a trailing odd frame so the count stays even.
        return frames - (frames % 2)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame budget left after reserving image tokens.

        The result is never below 2.
        """
        image_budget = self.get_max_image_tokens() * mm_counts.get("image", 0)
        total_frames = self._get_max_video_frames(seq_len - image_budget)
        frames_per_video = total_frames // max(mm_counts.get("video", 0), 1)
        return max(frames_per_video, 2)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of the largest video that fits ``seq_len``."""
        width, height = self.get_image_size_with_most_features()
        num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts)
        return self.get_num_video_tokens(
            image_width=width,
            image_height=height,
            num_frames=num_frames,
            image_processor=None,
        )


958
class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]):
    """Multimodal processor for Ernie 4.5 VL.

    Wraps the HuggingFace processor, normalizes its pixel output, and splits
    its combined image/video output into per-modality fields.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Video items are parsed together with their metadata.
        return MultiModalDataParser(
            video_needs_metadata=True,
        )

    def _pixel_values_norm(
        self,
        pixel_values: torch.Tensor,
        mm_kwargs: Mapping[str, object],
    ) -> torch.Tensor:
        """Rescale and normalize patch pixels in float32.

        Computes ``(rescale_factor * x - mean) / std``. The three per-channel
        mean/std values are each repeated ``patch_size**2`` times along the
        last dimension. The result is cast to ``hf_config.dtype``.

        Args:
            pixel_values: Raw pixel values from the HF processor.
                NOTE(review): assumed to have a last dimension of
                ``3 * patch_size**2`` laid out channel-major; confirm.
            mm_kwargs: Processor kwargs used to fetch the image processor.

        Returns:
            Normalized pixel values in ``hf_config.dtype``.
        """
        hf_config = self.info.get_hf_config()
        patch_size_squared = hf_config.vision_config.patch_size**2
        image_processor = self.info.get_image_processor(**mm_kwargs)

        def per_patch_stat(values) -> torch.Tensor:
            # (3,) channel stats -> (1, 3 * patch_size**2). reshape(1, 3)
            # also rejects stats that do not have exactly three channels.
            return (
                torch.tensor(values, dtype=torch.float32)
                .reshape(1, 3)
                .repeat_interleave(patch_size_squared, dim=-1)
            )

        image_mean = per_patch_stat(image_processor.image_mean)
        image_std = per_patch_stat(image_processor.image_std)
        rescale_factor = torch.tensor(
            image_processor.rescale_factor, dtype=torch.float32
        )

        normalized = (
            rescale_factor * pixel_values.to(torch.float32) - image_mean
        ) / image_std
        return normalized.to(hf_config.dtype)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and split its output by modality.

        A non-empty prompt without any images/videos is only tokenized.
        Otherwise the HF processor output is normalized, and its combined
        ``images`` / ``grid_thw`` fields are split into image and video fields.
        Entries whose temporal grid is > 1 are treated as videos.
        """
        # Text-only prompt: bypass the HF processor and tokenize directly.
        if "images" not in mm_data and "videos" not in mm_data and prompt != "":
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Work on a shallow copy so the caller's mapping is never mutated.
        mm_data = dict(mm_data)
        mm_data.setdefault("images", [])
        mm_data.setdefault("videos", [])

        # Check if HF processor supports video metadata
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        supports_video_metadata = getattr(
            hf_processor, "supports_video_metadata", False
        )

        if mm_data["videos"] and not supports_video_metadata:
            # Old HF processor, unwrap tuple to pure frames
            logger.warning_once(
                "HF processor doesn't support video metadata. "
                "Timestamps will NOT be rendered. Please upgrade the model."
            )
            mm_data["videos"] = [
                v[0] if isinstance(v, tuple) else v for v in mm_data["videos"]
            ]

        processor_output = self.info.ctx.call_hf_processor(
            hf_processor,
            dict(text=[prompt], images=mm_data["images"], videos=mm_data["videos"]),
            dict(**mm_kwargs, **tok_kwargs),
        )

        if processor_output is None:
            return processor_output

        pixel_values = processor_output["images"]
        if pixel_values is not None:
            processor_output["images"] = self._pixel_values_norm(
                pixel_values, mm_kwargs
            )

        # Drop fields the HF processor left unset. Done up front so the
        # later split does not depend on the output's key order.
        for key in [k for k, v in processor_output.items() if v is None]:
            del processor_output[key]

        if "grid_thw" in processor_output:
            grid_thw = processor_output["grid_thw"]
            pixel_values_all = processor_output["images"]

            # Rows whose temporal size is > 1 belong to videos.
            mask = grid_thw[:, 0] > 1
            processor_output["video_grid_thw"] = grid_thw[mask]
            processor_output["image_grid_thw"] = grid_thw[~mask]

            # Image patches precede video patches in the combined tensor.
            image_patch_num = int(
                processor_output["image_grid_thw"].prod(dim=1).sum()
            )
            processor_output["pixel_values"] = pixel_values_all[:image_patch_num]
            processor_output["pixel_values_videos"] = pixel_values_all[
                image_patch_num:
            ]
            del processor_output["images"]

        return processor_output

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video placeholder into its patch tokens.

        An image expands to ``prod(grid_thw) // spatial_conv_size**2``
        tokens. A video is additionally divided by ``temporal_conv_size``.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        before_placeholder = {
            "image": "<|image@placeholder|>",
            "video": "<|video@placeholder|>",
        }

        after_placeholder = {
            # image and video have same placeholder
            "image": "<|IMAGE_PLACEHOLDER|>",
            "video": "<|IMAGE_PLACEHOLDER|>",
        }

        merge_length = hf_processor.spatial_conv_size**2

        def get_replacement_ernie45vl(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            if modality == "video":
                num_tokens = (
                    int(grid_thw.prod())
                    // hf_processor.temporal_conv_size
                    // merge_length
                )
            else:
                num_tokens = int(grid_thw.prod()) // merge_length
            return after_placeholder[modality] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=before_placeholder[modality],
                replacement=partial(get_replacement_ernie45vl, modality=modality),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how flat pixel tensors are split into per-item chunks."""
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_grid_sizes = image_grid_thw.prod(-1)

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video"),
        )


1145
class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]):
    """Builds maximal-size dummy prompts and media for Ernie 4.5 VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt with one placeholder per requested image/video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        image_parts = [
            f"Picture {i + 1}:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
            for i in range(num_images)
        ]
        video_parts = [
            f"Video {i + 1}:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
            for i in range(num_videos)
        ]
        return "".join(image_parts + video_parts)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images/videos at the sizes that maximize token usage."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[tuple[np.ndarray, dict[str, Any]]]:
        """Return ``num_videos`` white dummy videos paired with metadata.

        Overrides may only shrink the frame count or dimensions; larger
        values are logged and capped. At least 2 frames are always produced.

        Returns:
            A list of ``(frames, metadata)`` tuples, where ``frames`` has
            shape ``(num_frames, height, width, 3)`` and dtype ``uint8``.
        """
        if overrides:
            if overrides.num_frames:
                if overrides.num_frames > num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        overrides.num_frames,
                        num_frames,
                    )
                num_frames = min(num_frames, overrides.num_frames)
            if overrides.width:
                if overrides.width > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        overrides.width,
                        width,
                    )
                width = min(width, overrides.width)
            if overrides.height:
                if overrides.height > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        overrides.height,
                        height,
                    )
                height = min(height, overrides.height)
        num_frames = max(num_frames, 2)  # ernie4.5-vl requires at least 2 frames

        # Frames are (T, H, W, C): rows follow height, columns follow width.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))
        return video_items

1246
1247
1248
1249

@MULTIMODAL_REGISTRY.register_processor(
    Ernie4_5VLMultiModalProcessor,
    info=Ernie4_5_VLProcessingInfo,
1250
1251
1252
    dummy_inputs=Ernie4_5_VLDummyInputsBuilder,
)
class Ernie4_5_VLMoeForConditionalGeneration(
1253
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
1254
):
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
    # Fused projection -> the separate checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Renames HF checkpoint weights to this module's parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
            # Intended chain for the resampler:
            #   model.resampler_model. -> language_model.model.resampler_model.
            #   language_model.model.resampler_model. -> resampler_model.
            "language_model.model.resampler_model.": "resampler_model.",
        },
        # Resampler Sequential indices -> named submodules.
        orig_to_new_substr={
            "spatial_linear.0.": "spatial_linear1.",
            "spatial_linear.2.": "spatial_linear2.",
            "spatial_linear.3.": "spatial_norm.",
            "temporal_linear.0.": "temporal_linear1.",
            "temporal_linear.2.": "temporal_linear2.",
            "temporal_linear.3.": "temporal_norm.",
        },
    )
1286
1287

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image or video item.

        Raises:
            ValueError: if ``modality`` starts with neither "image" nor "video".
        """
        placeholders = (
            ("image", "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"),
            ("video", "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"),
        )
        for prefix, placeholder in placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, language model and resampler.

        Also caches the ids of visual tokens (the image patch id plus any
        configured image/video start/end ids). ``_set_visual_token_mask``
        uses this cache.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_model = Ernie4_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = Ernie4_5_VLMoeForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Presumably maps vision features (pixel_hidden_size) to the LM
        # hidden size, merging patches spatially/temporally; see the class.
        self.resampler_model = VariableResolutionResamplerModel(
            self.config.pixel_hidden_size,
            self.config.hidden_size,
            self.config.spatial_conv_size,
            self.config.temporal_conv_size,
            config=self.config,
            prefix=maybe_prefix(prefix, "resampler_model"),
        )

        # Filled in by _set_visual_token_mask.
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        # Compare against None explicitly: a token id of 0 is still a valid id.
        im_patch_id = getattr(self.config, "im_patch_id", None)
        if im_patch_id is not None:
            visual_token_ids = [
                token_id
                for token_id in (
                    im_patch_id,
                    getattr(self.config, "image_start_token_id", None),
                    getattr(self.config, "image_end_token_id", None),
                    getattr(self.config, "video_start_token_id", None),
                    getattr(self.config, "video_end_token_id", None),
                )
                if token_id is not None
            ]
            self._visual_token_ids_tensor_cache = torch.tensor(
                visual_token_ids, dtype=torch.long
            )
        else:
            self._visual_token_ids_tensor_cache = None
1348
1349
1350
1351

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        language_model = self.language_model
        logits = language_model.compute_logits(hidden_states)
        return logits
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365

    def _vision_forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the vision tower with one grid row per frame.

        Non-positive entries of ``grid_thw`` are removed and the remainder is
        regrouped into ``(t, h, w)`` rows. Each row is then expanded into
        ``t`` rows of ``(1, h, w)``,
        e.g. ``[[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]]``.

        Raises:
            ValueError: if the filtered grid cannot be reshaped into rows of 3.
        """
        if grid_thw is not None:
            # NOTE(review): zeros are presumably padding entries; confirm.
            grid_thw = grid_thw[grid_thw > 0]
            if grid_thw.numel() % 3 != 0:
                raise ValueError(
                    f"grid_thw has {grid_thw.numel()} elements after filtering, "
                    "which is not divisible by 3."
                )
            grid_thw = grid_thw.reshape(-1, 3)
            # Repeat (h, w) t times, then prepend t=1 to every row.
            grid_thw = F.pad(
                torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0),
                [1, 0, 0, 0],
                value=1,
            )
        return self.vision_model(pixel_values, grid_thw)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Store an ``(N, 1)`` bool mask of visual tokens in ``input_ids``.

        A token is visual if its id is in the cached visual-token ids. When no
        ids are cached, the mask is reset to ``None``.
        """
        cached_ids = self._visual_token_ids_tensor_cache
        if cached_ids is not None:
            # Match the device/dtype of the incoming ids before comparing.
            ids = cached_ids.to(device=input_ids.device, dtype=input_ids.dtype)
            self.visual_token_mask = torch.isin(input_ids, ids).reshape(-1, 1)
        else:
            self.visual_token_mask = None
1392

1393
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-row (t, h, w) M-RoPE position ids for ``input_tokens``.

        Tokens are grouped into consecutive runs of text, image and video.
        An image-patch token is tagged "video" when it lies between the video
        start/end delimiters, and "image" otherwise.

        - A text run gets consecutive positions, repeated across all 3 rows.
        - An image run gets a ``(t, h // s, w // s)`` position grid.
        - A video run gets a ``(t // temporal_conv_size, h // s, w // s)``
          position grid.

        Here ``s`` is ``spatial_conv_size``. Every run starts one past the
        maximum position of the previous run. Images and videos are consumed
        from their own grid lists with independent counters, so prompts that
        mix both modalities stay aligned.

        Returns:
            ``(positions, delta)`` where ``positions`` has 3 rows and
            ``delta = positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]

        hf_config = self.config
        image_token_id = hf_config.im_patch_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_conv_size = hf_config.spatial_conv_size
        temporal_conv_size = hf_config.temporal_conv_size

        def grid_positions(
            grid_t: int, grid_h: int, grid_w: int, offset
        ) -> torch.Tensor:
            # (3, grid_t * grid_h * grid_w) positions in t-major order.
            t_index = (
                torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w).flatten()
            )
            h_index = (
                torch.arange(grid_h)
                .view(1, -1, 1)
                .expand(grid_t, -1, grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(grid_w)
                .view(1, 1, -1)
                .expand(grid_t, grid_h, -1)
                .flatten()
            )
            return torch.stack([t_index, h_index, w_index]) + offset

        llm_pos_ids_list: list[torch.Tensor] = []

        if not (image_grid_thw or video_grid_thw):
            text_len = len(input_tokens)
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
        else:
            # Tag every token as text / image patch / video patch.
            input_token_type: list[str] = []
            in_video = False
            for token in input_tokens:
                if token == video_start_token_id:
                    in_video = True
                elif token == video_end_token_id:
                    in_video = False

                if token == image_token_id:
                    input_token_type.append("video" if in_video else "image")
                else:
                    input_token_type.append("text")

            image_idx = 0
            video_idx = 0
            for modality_type, group_iter in itertools.groupby(
                enumerate(input_token_type), lambda x: x[1]
            ):
                group_list = list(group_iter)
                start_idx = group_list[0][0]
                end_idx = group_list[-1][0] + 1
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0

                if modality_type == "image":
                    t, h, w = image_grid_thw[image_idx]
                    image_idx += 1
                    llm_pos_ids_list.append(
                        grid_positions(
                            t,
                            h // spatial_conv_size,
                            w // spatial_conv_size,
                            st_idx,
                        )
                    )
                elif modality_type == "video":
                    t, h, w = video_grid_thw[video_idx]
                    video_idx += 1
                    llm_grid_t = t // temporal_conv_size
                    # Skip empty grids so the next run's `.max()` is valid.
                    if llm_grid_t > 0:
                        llm_pos_ids_list.append(
                            grid_positions(
                                llm_grid_t,
                                h // spatial_conv_size,
                                w // spatial_conv_size,
                                st_idx,
                            )
                        )
                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return llm_positions, mrope_position_delta

1524
1525
1526
1527
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text-generation model."""
        language_model = self.language_model
        return language_model

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Ernie4_5_VLImageInputs | None:
        """Pop image fields from ``kwargs`` and wrap them as pixel inputs.

        Returns ``None`` when no ``pixel_values`` are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None:
            return None

        return Ernie4_5_VLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )
1542
1543

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Ernie4_5_VLVideoInputs | None:
        """Pop video fields from ``kwargs`` and wrap them as pixel inputs.

        Returns ``None`` when no ``pixel_values_videos`` are present.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None:
            return None

        return Ernie4_5_VLVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )

    def _process_image_input(
        self, image_input: Ernie4_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images and split the result into one tensor per image.

        The pixel values are cast to the vision tower's dtype and passed
        through ``_vision_forward`` and then ``resampler_model``. The
        concatenated output is split by per-image token counts, computed as
        ``prod(image_grid_thw row) // spatial_merge_size // spatial_merge_size``.

        Returns:
            A tuple of tensors, one per row of ``image_grid_thw``.
        """
        grid_thw = image_input["image_grid_thw"]
        # Expected: one row per image (presumably (t, h, w) -- confirm).
        assert grid_thw.ndim == 2

        pixel_values = image_input["pixel_values"].type(self.vision_model.dtype)
        image_features = self._vision_forward(
            pixel_values=pixel_values, grid_thw=grid_thw
        )
        image_embeds = self.resampler_model(image_features, grid_thw)

        # Each spatial_merge_size x spatial_merge_size patch group becomes one
        # output token, so divide the patch count by the merge area.
        merge_size = self.vision_model.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size

        return image_embeds.split(sizes.tolist())

    def _process_video_input(
        self, video_input: Ernie4_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos and split the result into one tensor per video.

        The pixel values are cast to the vision tower's dtype and passed
        through ``_vision_forward`` and then ``resampler_model``. The
        concatenated output is split by per-video token counts, computed as
        ``prod(video_grid_thw row) // temporal_conv_size``, then divided twice
        by ``spatial_merge_size``.

        Returns:
            A tuple of tensors, one per row of ``video_grid_thw``.
        """
        grid_thw = video_input["video_grid_thw"]
        # Expected: one row per video (presumably (t, h, w) -- confirm).
        assert grid_thw.ndim == 2

        pixel_values_videos = video_input["pixel_values_videos"].type(
            self.vision_model.dtype
        )
        video_features = self._vision_forward(
            pixel_values=pixel_values_videos, grid_thw=grid_thw
        )
        video_embeds = self.resampler_model(video_features, grid_thw)

        # Token counts shrink by temporal_conv_size as well as by the spatial
        # merge area. Presumably the resampler also merges along time -- confirm.
        merge_size = self.vision_model.spatial_merge_size
        sizes = (
            (grid_thw.prod(-1) // self.config.temporal_conv_size)
            // merge_size
            // merge_size
        )

        return video_embeds.split(sizes.tolist())

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
1615
1616
1617

        return modalities

1618
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Compute embeddings for every image/video item found in ``kwargs``.

        Returns:
            ``None`` if no multimodal input was parsed. Otherwise a tuple with
            one tensor per multimodal item (image or video), grouped by
            modality in the order the modalities were parsed.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, mm_input in modalities.items():
            # A parser can yield None (e.g. for an embeds-only kwarg it does
            # not handle). There is nothing to encode in that case.
            if mm_input is None:
                continue
            if modality == "images":
                image_embeddings = self._process_image_input(mm_input)
                multimodal_embeddings += tuple(image_embeddings)
            elif modality == "videos":
                video_embeddings = self._process_video_input(mm_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in multimodal embeddings when given.

        When non-empty ``multimodal_embeddings`` are supplied,
        ``_set_visual_token_mask(input_ids)`` is called first. Presumably it
        stores ``self.visual_token_mask``, which ``forward`` consumes -- confirm.
        Embedding itself is delegated to the parent implementation.
        """
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the language-model backbone and return its hidden states.

        If ``self.visual_token_mask`` is set, it is handled in three steps:
        1. Right-pad it with ``False`` rows so its first dimension matches
           ``inputs_embeds``.
        2. Pass it to the backbone as ``visual_token_mask``.
        3. Reset it to ``None`` so it is used for a single forward call.

        Extra ``kwargs`` are forwarded unchanged to
        ``self.language_model.model``.
        """
        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        if self.visual_token_mask is not None:
            visual_token_mask = self.visual_token_mask
            if visual_token_mask.shape[0] != inputs_embeds.shape[0]:
                padding_len = inputs_embeds.shape[0] - visual_token_mask.shape[0]
                # right pad False
                pad = torch.zeros(
                    (padding_len, visual_token_mask.shape[1]),
                    dtype=visual_token_mask.dtype,
                    device=visual_token_mask.device,
                )
                visual_token_mask = torch.cat([visual_token_mask, pad], dim=0)

            forward_kwargs["visual_token_mask"] = visual_token_mask
            # Clear the pending mask so a later forward call does not reuse it.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(
            **forward_kwargs,
            **kwargs,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model via ``AutoWeightsLoader``.

        Checkpoint names are remapped with ``self.hf_to_vllm_mapper``.
        Returns whatever set ``AutoWeightsLoader.load_weights`` returns
        (presumably the names of the loaded parameters -- confirm).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)