ernie45_vl.py 60.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Ernie VL model compatible with HuggingFace weights."""
25

26
import itertools
27
import math
28
from collections.abc import Callable, Iterable, Mapping, Sequence
29
from functools import partial
30
from typing import Annotated, Any, Literal
31
32
33
34
35
36

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
37
from transformers import BatchFeature, PretrainedConfig
38

39
from vllm.attention.backends.registry import _Backend
40
41
42
43
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
44
from vllm.config import VllmConfig
45
from vllm.config.multimodal import BaseDummyOptions
46
47
48
49
50
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.layernorm import RMSNorm
51
52
53
54
55
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
56
57
58
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
59
60
61
62
63
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
64
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
65
66
67
68
69
70
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
71
from vllm.multimodal.profiling import BaseDummyInputsBuilder
72
from vllm.platforms import current_platform
73
from vllm.sequence import IntermediateTensors
74
from vllm.utils.tensor_schema import TensorSchema, TensorShape
75
76

from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
77
78
79
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
80
    SupportsMRoPE,
81
82
83
    SupportsMultiModal,
    SupportsPP,
)
84
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
85
86
87
88
89
90
91
92
93
94
95
96
97
from .vision import get_vit_attn_backend

# Module-level logger, named after this module.
logger = init_logger(__name__)

# === Vision Transformer === #


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Build the "rotated" companion of ``x`` used by rotary embeddings.

    With ``interleaved=False`` the last dimension is split into two contiguous
    halves ``(a, b)`` and ``(-b, a)`` is returned. With ``interleaved=True``
    the even/odd positions of the last dimension are paired instead, and each
    pair ``(a, b)`` is replaced by ``(-b, a)`` at the same positions.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        paired = torch.stack((-odds, evens), dim=-1)
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
101
102


103
104
105
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Apply rotary embeddings to the leading ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features of ``x`` beyond ``rotary_dim`` are passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    # Widen cos/sin to ``rotary_dim`` and insert a size-1 axis before the last
    # dimension so they broadcast against ``x``.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos_full = repeat(cos, pattern)
    sin_full = repeat(sin, pattern)

    to_rotate, passthrough = x[..., :rotary_dim], x[..., rotary_dim:]
    rotated = to_rotate * cos_full + rotate_half(to_rotate, interleaved) * sin_full
    return torch.cat([rotated, passthrough], dim=-1)


127
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``t`` by the angles in ``freqs`` and return it in ``t``'s dtype.

    The rotation is computed on a float32 copy of ``t``. On CUDA platforms the
    flash-attn rotary implementation is used; everywhere else the pure-PyTorch
    ``apply_rotary_emb_torch`` is used.
    """
    t_fp32 = t.float()
    cos, sin = freqs.cos(), freqs.sin()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb as rotary_fn
    else:
        rotary_fn = apply_rotary_emb_torch

    return rotary_fn(t_fp32, cos, sin).type_as(t)


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` over the TP group and interleave the pieces.

    Every rank's tensor is cut along the last dimension into chunks of width
    ``hidden_size // tp_size``. The result concatenates chunk 0 of every rank,
    then chunk 1 of every rank, and so on, along the last dimension.
    """
    import torch.distributed as dist

    per_rank = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    tp_group = parallel_state.get_tp_group().device_group
    dist.all_gather(per_rank, local_tensor, group=tp_group)

    chunk_width = hidden_size // tp_size
    per_rank_chunks = [torch.split(gathered, chunk_width, -1) for gathered in per_rank]

    # zip(*...) groups the i-th chunk of every rank together; flattening those
    # groups in order yields the interleaved layout.
    interleaved = list(itertools.chain.from_iterable(zip(*per_rank_chunks)))
    return torch.cat(interleaved, dim=-1)


class Ernie4_5_VisionAttention(nn.Module):
    """Multi-head self-attention for the vision tower, built on vLLM layers.

    Uses a fused, tensor-parallel QKV projection followed by one of several
    attention backends (flash-attn style varlen kernels, PyTorch SDPA, or
    xFormers) and a row-parallel output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ) -> None:
        """Create the projections and select the attention backend.

        Args:
            embed_dim: Input width of ``qkv`` and output width of ``proj``.
            num_heads: Total number of attention heads (before TP sharding).
            projection_size: Combined width of all heads; also the input
                width of ``proj``.
            quant_config: Forwarded to both parallel linear layers.
            prefix: Parameter-name prefix for the sub-layers.
            attn_backend_override: Forwarded to ``get_vit_attn_backend``.

        Raises:
            RuntimeError: If the resolved backend is not one of FLASH_ATTN,
                TORCH_SDPA, XFORMERS or ROCM_AITER_FA.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # K/V use the same head count as Q (total_num_kv_heads == num_heads).
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

        self.use_upstream_fa = False

        # The helper may replace the backend and also returns the varlen
        # attention callable used by the flash-attn branch of forward().
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Ernie45-VL does not support {self.attn_backend} backend now."
            )
        # Both of these backends go through ``flash_attn_varlen_func``.
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused QKV projection output into this rank's q, k and v.

        With tensor parallelism enabled, the local output is first gathered
        across ranks with ``all_gather_interleave``, chunked into q/k/v, and
        each of those is split again so that only this rank's slice is kept.

        Returns:
            ``(q, k, v)``, each viewed as
            ``[seq_len, batch, heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run self-attention over packed variable-length sequences.

        Args:
            x: Input of shape ``[s, b, c]`` (per the shape comments below).
            cu_seqlens: Cumulative sequence boundaries; consecutive entries
                delimit one attention segment.
            rotary_pos_emb: Rotary angles applied to q and k; skipped when
                ``None``.
            max_seqlen: Longest segment length, passed to the flash-attn
                varlen kernel.
            seqlens: Per-segment lengths, used to build the xFormers
                block-diagonal mask.

        Returns:
            The output of ``proj`` applied to the ``[s, b, h * d]`` context.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in a single call by stacking them on the batch axis.
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            # The varlen kernel takes batch and sequence flattened together.
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            for i in range(1, len(cu_seqlens)):
                # Each [start_idx, end_idx) window along the sequence axis is
                # attended to independently of the others.
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (
                    rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # One mask block per segment, built from the per-segment lengths.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        # No else branch is needed: __init__ rejects any other backend.
        output, _ = self.proj(context_layer)
        return output


class Ernie4_5_VisionMLP(nn.Module):
    """Feed-forward block of the vision tower: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is column-parallel and ``fc2`` is row-parallel.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        act_layer: type[nn.Module] = QuickGELU,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build both linear layers and instantiate ``act_layer``.

        Args:
            in_features: Width of the block's input and output.
            hidden_features: Width of the intermediate activation.
            act_layer: Module class instantiated (with no arguments) as the
                activation.
            quant_config: Forwarded to both linear layers.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()

        fc1_prefix = f"{prefix}.fc1"
        fc2_prefix = f"{prefix}.fc2"

        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, quant_config=quant_config, prefix=fc1_prefix
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, quant_config=quant_config, prefix=fc2_prefix
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation and ``fc2`` in sequence."""
        # Both parallel linears return a tuple; only the first item is used.
        hidden, _ = self.fc1(x)
        activated = self.act(hidden)
        projected, _ = self.fc2(activated)
        return projected


class Ernie4_5_VisionBlock(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: type[nn.Module] = QuickGELU,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ) -> None:
        """Build the two norms, the attention module and the MLP.

        Args:
            dim: Hidden width of the layer.
            num_heads: Number of attention heads.
            mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
            act_layer: Activation module class used by the MLP.
            norm_layer: Factory called with ``dim`` to build each norm;
                defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
            quant_config: Forwarded to the attention and MLP sub-modules.
            prefix: Parameter-name prefix for the sub-modules.
            attn_backend_override: Forwarded to the attention module.
        """
        super().__init__()

        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Ernie4_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_backend_override=attn_backend_override,
        )

        self.mlp = Ernie4_5_VisionMLP(
            dim,
            int(dim * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Apply attention then the MLP, adding a residual around each."""
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        after_attn = hidden_states + attn_out
        return after_attn + self.mlp(self.norm2(after_attn))


class Ernie4_5_VisionPatchEmbed(nn.Module):
    """Bias-free linear projection of flattened patches to ``embed_dim``."""

    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 1280,
        prefix="",
    ) -> None:
        """Create the projection.

        The projection's input width is
        ``in_channels * patch_size * patch_size``. ``prefix`` is accepted but
        not used by this module.
        """
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        flat_patch_dim = in_channels * patch_size * patch_size
        self.proj = nn.Linear(flat_patch_dim, embed_dim, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Cast the input to the projection weight's dtype and project it."""
        weight_dtype = self.proj.weight.dtype
        return self.proj(hidden_states.to(weight_dtype))


class Ernie4_5_VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles: position index times inverse frequency."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """Precompute ``inv_freq[i] = 1 / theta ** (2 * i / dim)`` in float32.

        ``inv_freq`` is stored as a plain tensor attribute; it is not
        registered as a buffer or a parameter.
        """
        super().__init__()
        exponents = torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim
        self.inv_freq = 1.0 / (theta**exponents)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, len(inv_freq))`` outer product of positions
        ``0..seqlen-1`` with ``inv_freq``."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(input=positions, vec2=inv_freq)


class Ernie4_5_VisionTransformer(nn.Module):
    """Vision tower: patch embedding, a stack of attention blocks and a final
    LayerNorm, with 2D rotary position embeddings over the patch grid."""

    def __init__(
        self,
        vision_config,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_backend_override: _Backend | None = None,
    ) -> None:
        """Build the tower from ``vision_config``.

        Reads ``patch_size``, ``spatial_merge_size``, ``in_channels``,
        ``hidden_size``, ``embed_dim``, ``depth``, ``num_heads`` and
        ``mlp_ratio`` from ``vision_config``. ``norm_eps`` is used for the
        per-block LayerNorms; the final ``ln`` always uses ``eps=1e-6``.
        """
        super().__init__()
        patch_size = vision_config.patch_size
        spatial_merge_size = vision_config.spatial_merge_size
        in_channels = vision_config.in_channels
        hidden_size = vision_config.hidden_size
        embed_dim = vision_config.embed_dim
        depth = vision_config.depth
        num_heads = vision_config.num_heads
        mlp_ratio = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.patch_embed = Ernie4_5_VisionPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            prefix=f"{prefix}.patch_embed",
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Half of head_dim: rot_pos_emb() concatenates the row and column
        # angle tables for each patch.
        self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Ernie4_5_VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )

        assert hidden_size == embed_dim, (
            "vit's config.hidden must be equal to config.embed_dim"
        )
        self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

        # Resolved here as well so compute_attn_mask_seqlen() knows which
        # auxiliary inputs to precompute for the blocks.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-patch rotary angles for every ``(t, h, w)`` grid.

        For each grid, row and column indices are generated for the ``h x w``
        patches and reordered so that the patches of each
        ``spatial_merge_size x spatial_merge_size`` window are adjacent. The
        resulting index pairs are repeated ``t`` times, used to look up the
        angle table, and the row/column angles are flattened together.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # (h, w) -> (h/m, m, w/m, m) -> (h/m, w/m, m, m) -> flat, so that
            # each merge window's patches become contiguous.
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            # The same spatial positions are reused for each of the t frames.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # The angle table only needs to cover the largest height/width present.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute the backend-specific sequence-length inputs.

        Returns:
            ``(max_seqlen, seqlens)``: ``max_seqlen`` is set only for the
            FLASH_ATTN / ROCM_AITER_FA backends, ``seqlens`` only for
            XFORMERS; the other value (or both) is ``None``.
        """
        max_seqlen, seqlens = None, None
        if (
            self.attn_backend == _Backend.FLASH_ATTN
            or self.attn_backend == _Backend.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad: int = 0
    ) -> torch.Tensor:
        """Encode flattened patches into per-patch features.

        Args:
            hidden_states: Flattened patches fed to ``patch_embed``.
            grid_thw: One ``(t, h, w)`` row per image/video.
            num_pad: If positive, an extra attention segment of this length
                is appended after the real ones.

        Returns:
            The output of the final LayerNorm, with the batch axis at dim 1
            squeezed out when the result is 3-D.
        """
        hidden_states = self.patch_embed(hidden_states)

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)

        # Each frame (h * w patches) is its own attention segment; a grid
        # with t frames contributes t segments.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)

        zeros = cu_seqlens.new_zeros(1)
        if num_pad > 0:
            # Leading zero plus one trailing slot, which is then overwritten
            # with the end boundary of the padding segment.
            cu_seqlens = torch.cat([zeros, cu_seqlens, zeros])
            cu_seqlens[-1] = cu_seqlens[-2] + num_pad
        else:
            cu_seqlens = torch.cat([zeros, cu_seqlens])

        # add batch size
        if hidden_states.ndim == 2:
            hidden_states = hidden_states.unsqueeze(dim=1)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

        for i, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        final_output = self.ln(hidden_states)

        if final_output.ndim == 3:
            final_output = final_output.squeeze(dim=1)

        return final_output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Uses a parameter's own ``weight_loader`` attribute when present and
        ``default_weight_loader`` otherwise.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name does not match any parameter of this module.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


# === Vision Inputs === #


632
class Ernie4_5_VLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs: flattened patches plus one grid row per image.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
    """

    # Discriminator tag for this input kind.
    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    # Three values per image; presumably (t, h, w) as the name suggests.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
645
646
647
648
649


# Image inputs have a single schema in this file, so the alias refers to it directly.
Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs


650
class Ernie4_5_VLVideoPixelInputs(TensorSchema):
    """
    Schema for video inputs: flattened patches plus one grid row per video.

    Dimensions:
        - np: The total number of patches over each video over each prompt in
              the batch
        - ni: Number of videos
        - cps: Number of channels * temporal_patch_size * patch_size *
              patch_size
    """

    # Discriminator tag for this input kind.
    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
    # Three values per video; presumably (t, h, w) as the name suggests.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
663
664


665
# Video inputs have a single schema in this file, so the alias refers to it directly.
Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
666
667
668
669

# === Vision Processor === #


670
def round_by_factor(number: int | float, factor: int) -> int:
671
672
673
    return round(number / factor) * factor


674
def ceil_by_factor(number: int | float, factor: int) -> int:
675
676
677
    return math.ceil(number / factor) * factor


678
def floor_by_factor(number: int | float, factor: int) -> int:
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 16384 * 28 * 28,
) -> tuple[int, int]:
    """Pick a target size whose sides are multiples of ``factor``.

    The aspect ratio is first clamped to at most 200:1. Both sides are then
    rounded to multiples of ``factor`` (at least ``factor`` each) and, if the
    area falls outside ``[min_pixels, max_pixels]``, rescaled so that it fits.

    Args:
        height: Source height in pixels; must be positive.
        width: Source width in pixels; must be positive.
        factor: Both returned sides are multiples of this; must be positive.
        min_pixels: Lower bound on the returned area.
        max_pixels: Upper bound on the returned area.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: If ``height``, ``width`` or ``factor`` is not positive, or
            if no size satisfying the pixel bounds could be produced.
    """
    # Reject degenerate input up front. Without this a zero height, width or
    # factor surfaces as a bare ZeroDivisionError, and a negative side as a
    # math-domain error or a meaningless size.
    if height <= 0 or width <= 0 or factor <= 0:
        raise ValueError(
            "height, width and factor must be positive, got "
            f"height: {height}, width: {width}, factor: {factor}"
        )

    MAX_RATIO = 200
    if max(height, width) / min(height, width) > MAX_RATIO:
        # Keep the short side and shrink the long side down to MAX_RATIO
        # times the short side.
        if height > width:
            new_width = max(factor, round_by_factor(width, factor))
            new_height = floor_by_factor(new_width * MAX_RATIO, factor)
        else:
            new_height = max(factor, round_by_factor(height, factor))
            new_width = floor_by_factor(new_height * MAX_RATIO, factor)

        height = new_height
        width = new_width

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: scale down, rounding down so the area stays under the cap.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: scale up, rounding up so the area reaches the minimum.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
        raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")

    return h_bar, w_bar


class VariableResolutionResamplerModel(nn.Module):
    """Resample vision features to ``out_dim``.

    The pipeline is:

    1. Spatial merge: ``spatial_conv_size ** 2`` consecutive feature rows are
       flattened into one row and passed through a two-layer MLP + LayerNorm.
    2. Temporal merge (only if ``config.use_temporal_conv``): features of
       adjacent frame pairs are concatenated and passed through a second
       two-layer MLP + LayerNorm.
    3. A final linear projection to ``out_dim`` followed by RMSNorm.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        spatial_conv_size,
        temporal_conv_size,
        config,
        prefix: str = "",
    ) -> None:
        """Build the spatial, (optional) temporal and output projections.

        Args:
            in_dim: Feature width of each incoming row.
            out_dim: Feature width of the returned rows.
            spatial_conv_size: Side length of the spatial merge window.
            temporal_conv_size: Number of frames merged by the temporal
                stage. NOTE(review): forward() always concatenates exactly
                two frames, so this presumably has to be 2 - confirm.
            config: Read for ``use_temporal_conv`` and, optionally,
                ``quant_config`` and ``rms_norm_eps``.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.config = config
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size
        self.use_temporal_conv = config.use_temporal_conv

        # compress 2d conv(picture) to 1d
        self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size
        # compress 3d conv(video) to 1d
        self.temporal_dim = (
            self.in_dim
            * self.spatial_conv_size
            * self.spatial_conv_size
            * self.temporal_conv_size
        )

        # Every linear layer below shares the same optional quant config.
        quant_config = getattr(config, "quant_config", None)

        self.spatial_linear1 = ColumnParallelLinear(
            self.spatial_dim,
            self.spatial_dim,
            bias=True,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.spatial_linear1",
        )

        self.spatial_gelu = nn.GELU()

        self.spatial_linear2 = ColumnParallelLinear(
            self.spatial_dim,
            self.spatial_dim,
            bias=True,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.spatial_linear2",
        )

        self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6)

        if self.use_temporal_conv:
            self.temporal_linear1 = ColumnParallelLinear(
                self.temporal_dim,
                self.spatial_dim,
                bias=True,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.temporal_linear1",
            )

            self.temporal_gelu = nn.GELU()

            self.temporal_linear2 = ColumnParallelLinear(
                self.spatial_dim,
                self.spatial_dim,
                bias=True,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.temporal_linear2",
            )

            self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6)

        self.mlp = ColumnParallelLinear(
            self.spatial_dim,
            self.out_dim,
            bias=True,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        self.after_norm = RMSNorm(
            hidden_size=out_dim, eps=getattr(config, "rms_norm_eps", 1e-6)
        )

    def spatial_conv_reshape(self, x, spatial_conv_size):
        """Flatten every ``spatial_conv_size ** 2`` consecutive rows of the
        2-D input ``x`` into a single row."""
        # Unpacking also enforces that x is 2-D.
        _, C = x.shape
        return x.reshape([-1, C * (spatial_conv_size**2)])

    def _fwd_spatial(self, x):
        """Spatial merge followed by linear -> GELU -> linear -> LayerNorm."""
        x = self.spatial_conv_reshape(x, self.spatial_conv_size)

        x, _ = self.spatial_linear1(x)
        x = self.spatial_gelu(x)
        x, _ = self.spatial_linear2(x)
        return self.spatial_norm(x)

    def _fwd_placeholder(self, x, grid_thw):
        """Concatenate the features of each pair of adjacent frames.

        For every item in ``grid_thw`` the rows of frames 0, 2, 4, ... are
        gathered into one tensor and those of frames 1, 3, 5, ... into
        another; the two are concatenated along the feature axis. An item
        with a single frame (an image) contributes frame 0 to both halves.
        """
        merge_area = self.spatial_conv_size**2
        grid_thw_cpu = grid_thw.cpu().numpy()
        grid_t = grid_thw_cpu[:, 0]
        tokens_per_frame = grid_thw_cpu[:, 1:].prod(-1) // merge_area
        tokens_per_item = grid_thw_cpu.prod(-1) // merge_area

        # Exclusive prefix sum: the row at which each item starts in ``x``.
        batch_offset = np.empty(tokens_per_item.size, dtype=tokens_per_item.dtype)
        batch_offset[0] = 0
        batch_offset[1:] = tokens_per_item.cumsum()[:-1]

        def frame_rows(first_frame):
            """Row indices of every second frame of each item, starting at
            ``first_frame(num_frames)``."""
            chunks = [
                np.arange(
                    item_offset + frame * spatial_size,
                    item_offset + (frame + 1) * spatial_size,
                )
                for temporal_size, spatial_size, item_offset in zip(
                    grid_t, tokens_per_frame, batch_offset
                )
                for frame in range(first_frame(temporal_size), temporal_size, 2)
            ]
            return torch.tensor(np.concatenate(chunks, axis=-1)).to(x.device)

        even_rows = frame_rows(lambda temporal_size: 0)
        # Single-frame items have no frame 1, so they reuse frame 0.
        odd_rows = frame_rows(lambda temporal_size: 1 if temporal_size > 1 else 0)

        x_timestep_1 = torch.index_select(x, dim=0, index=even_rows)
        x_timestep_2 = torch.index_select(x, dim=0, index=odd_rows)
        return torch.concat([x_timestep_1, x_timestep_2], dim=-1)

    def _fwd_temporal(self, x):
        """Temporal MLP: linear -> GELU -> linear -> LayerNorm."""
        x, _ = self.temporal_linear1(x)
        x = self.temporal_gelu(x)
        x, _ = self.temporal_linear2(x)
        return self.temporal_norm(x)

    def _fwd_mlp(self, x):
        """Final projection to ``out_dim`` followed by RMSNorm."""
        x, _ = self.mlp(x)
        return self.after_norm(x)

    def forward(self, x, grid_thw):
        """Resample ``x`` (2-D feature rows) given the per-item grids.

        Args:
            x: 2-D tensor of vision features, one row per patch.
            grid_thw: One ``(t, h, w)`` row per image/video; only used when
                the temporal stage is enabled.

        Returns:
            The resampled features with ``out_dim`` columns.
        """
        x = self._fwd_spatial(x)
        if self.use_temporal_conv:
            x = self._fwd_placeholder(x, grid_thw)
            x = self._fwd_temporal(x)
        return self._fwd_mlp(x)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Names that do not match a parameter of this module are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Ernie 4.5 VL.

    Exposes the HF config/processor handles and derives vision-token counts
    for images and videos from the model's patch and conv sizes.
    """

    def get_hf_config(self):
        """Return the HF config stored on the processing context."""
        return self.ctx.model_config.hf_config

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, always requesting ``use_fast=True``."""
        return self.ctx.get_hf_processor(use_fast=True, **kwargs)

    def get_image_processor(self, **kwargs: object):
        """Return the ``image_processor`` attribute of the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits (``None`` for both modalities)."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of vision tokens for one image / one video."""
        max_image_tokens = self.get_max_image_tokens()
        max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts)
        return {"image": max_image_tokens, "video": max_video_tokens}

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor: Any | None,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of one item.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: When true, the size is first passed through
                ``smart_resize`` with ``factor = patch_size * spatial_conv_size``
                and the image processor's ``min_pixels`` / ``max_pixels``.
            image_processor: Image processor to read the pixel bounds from;
                ``None`` resolves it via ``get_image_processor()``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config

        patch_size = vision_config.patch_size
        spatial_conv_size = hf_config.spatial_conv_size
        temporal_conv_size = hf_config.temporal_conv_size

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * spatial_conv_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # At least one temporal step, even when there are fewer frames than
        # temporal_conv_size (e.g. a single image).
        grid_t = max(num_frames // temporal_conv_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (spatial_conv_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: Any | None,
    ) -> int:
        """Return the vision-token count of a single image of the given size."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Any | None,
    ) -> int:
        """Return the vision-token count of a video with ``num_frames`` frames."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size obtained from a very large input."""
        # An oversized request lets the resize step clamp to its upper bound.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of an image of the most-features size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest even frame count whose tokens fit ``max_tokens``.

        Frames are added one at a time at the most-features size until the
        token count would exceed ``max_tokens``; an odd result is rounded down.
        """
        target_width, target_height = self.get_image_size_with_most_features()
        # Resolve the processor once: it does not change between iterations,
        # and passing None would re-resolve it for every probed frame count.
        image_processor = self.get_image_processor()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=image_processor,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        # If the number of frames is odd, discard one frame.
        if num_frames % 2 != 0:
            num_frames -= 1

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame count used for worst-case sizing.

        The budget left after reserving the maximum tokens for every image is
        converted to frames and split evenly across videos; the result is
        never below 2.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        # max(..., 1) avoids dividing by zero when no videos are requested.
        max_frames_per_video = max_total_frames // max(max_videos, 1)

        return max(max_frames_per_video, 2)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the token count of one worst-case video."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
            image_processor=None,
        )


1065
class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]):
    """Multimodal processor for Ernie 4.5 VL image and video inputs."""

    def _pixel_values_norm(
        self,
        pixel_values: torch.Tensor,
        mm_kwargs: Mapping[str, object],
    ) -> torch.Tensor:
        """Rescale and normalize patch pixels.

        Computes ``(rescale_factor * pixel_values - mean) / std`` in float32
        using the image processor's statistics, then casts the result to
        ``hf_config.dtype``.
        """
        hf_config = self.info.get_hf_config()
        image_processor = self.info.get_image_processor(**mm_kwargs)
        patch_size_squared = hf_config.vision_config.patch_size**2

        def per_patch_stat(channel_values) -> torch.Tensor:
            # reshape(1, 3) enforces exactly three channel values; each one is
            # then repeated patch_size**2 times -> shape (1, 3 * patch_size**2).
            # NOTE(review): presumably matches a channel-major flattened patch
            # layout in pixel_values - confirm against the HF processor.
            return (
                torch.tensor(channel_values, dtype=torch.float32)
                .reshape(1, 3)
                .repeat_interleave(patch_size_squared, dim=-1)
            )

        image_mean_tensor = per_patch_stat(image_processor.image_mean)
        image_std_tensor = per_patch_stat(image_processor.image_std)
        rescale_factor = torch.tensor(
            image_processor.rescale_factor, dtype=torch.float32
        )

        pixel_values = (
            rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor
        ) / image_std_tensor
        return pixel_values.to(hf_config.dtype)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and split its output into image/video fields.

        A non-empty prompt without any multimodal data is only tokenized.
        Otherwise the processor's combined ``images`` / ``grid_thw`` outputs
        are normalized and split into ``pixel_values`` + ``image_grid_thw``
        and ``pixel_values_videos`` + ``video_grid_thw``.
        """
        # when the prompt is not empty but the multimodal data is empty,
        # directly invoke the tokenizer.
        if "images" not in mm_data and "videos" not in mm_data and prompt != "":
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Read with defaults instead of writing into the caller's mapping.
        images = mm_data.get("images", [])
        videos = mm_data.get("videos", [])
        processor_output = self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=[prompt], images=images, videos=videos),
            dict(**mm_kwargs, **tok_kwargs),
        )
        if processor_output is None:
            return processor_output

        pixel_values = processor_output.get("images")
        if pixel_values is not None:
            processor_output["images"] = self._pixel_values_norm(
                pixel_values, mm_kwargs
            )

        # Drop empty entries first so the split below never depends on the
        # order in which the processor happened to emit its keys.
        for key in [k for k, v in processor_output.items() if v is None]:
            del processor_output[key]

        # Divide the processor_output into two modalities: image and video.
        grid_thw = processor_output.get("grid_thw")
        if grid_thw is not None:
            pixel_values_all = processor_output["images"]
            # Identify elements where the first dimension is greater than 1
            # and treat them as the video modality.
            mask = grid_thw[:, 0] > 1
            processor_output["video_grid_thw"] = grid_thw[mask]
            processor_output["image_grid_thw"] = grid_thw[~mask]
            # The first `image_patch_num` rows are taken as image patches and
            # the remainder as video patches.
            image_patch_num = processor_output["image_grid_thw"].prod(dim=1).sum()
            processor_output["pixel_values"] = pixel_values_all[:image_patch_num]
            processor_output["pixel_values_videos"] = pixel_values_all[
                image_patch_num:
            ]
            del processor_output["images"]

        return processor_output

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/video placeholder with its run of patch tokens."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        before_placeholder = {
            "image": "<|image@placeholder|>",
            "video": "<|video@placeholder|>",
        }

        after_placeholder = {
            # image and video have same placeholder
            "image": "<|IMAGE_PLACEHOLDER|>",
            "video": "<|IMAGE_PLACEHOLDER|>",
        }

        merge_length = hf_processor.spatial_conv_size**2

        def get_replacement_ernie45vl(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            num_tokens = int(grid_thw.prod())
            if modality == "video":
                # Videos are additionally reduced along the temporal axis.
                num_tokens //= hf_processor.temporal_conv_size
            num_tokens //= merge_length
            return after_placeholder[modality] * num_tokens

        return [
            PromptReplacement(
                modality=modality,
                target=before_placeholder[modality],
                replacement=partial(get_replacement_ernie45vl, modality=modality),
            )
            for modality in ("image", "video")
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the processor outputs map onto individual items.

        Pixel tensors are flat and sliced per item by the product of that
        item's ``grid_thw`` row; the grids themselves are batched per item.
        """
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_grid_sizes = image_grid_thw.prod(-1)

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video"),
        )


1230
class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]):
    """Builds dummy prompts and multimodal data for Ernie 4.5 VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt with one placeholder per requested image and video.

        All image placeholders come first, followed by all video placeholders;
        items are numbered from 1.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        picture_parts = [
            f"Picture {n + 1}:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
            for n in range(image_count)
        ]
        video_parts = [
            f"Video {n + 1}:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
            for n in range(video_count)
        ]
        return "".join(picture_parts + video_parts)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the most-features size.

        Per-modality entries of ``mm_options`` (when given) are forwarded as
        overrides to the dummy image/video factories.
        """
        width, height = self.info.get_image_size_with_most_features()
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        if mm_options:
            image_overrides = mm_options.get("image")
            video_overrides = mm_options.get("video")
        else:
            image_overrides = None
            video_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=image_overrides,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frames_per_video,
            num_videos=mm_counts.get("video", 0),
            overrides=video_overrides,
        )
        return {"image": dummy_images, "video": dummy_videos}


@MULTIMODAL_REGISTRY.register_processor(
    Ernie4_5VLMultiModalProcessor,
    info=Ernie4_5_VLProcessingInfo,
1281
1282
1283
    dummy_inputs=Ernie4_5_VLDummyInputsBuilder,
)
class Ernie4_5_VLMoeForConditionalGeneration(
1284
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
1285
):
1286
    merge_by_field_config = True
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316

    # Packed module name -> names of the sub-modules grouped under it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
            # Resampler weights are renamed in two steps:
            #   model.resampler_model. -> language_model.model.resampler_model.
            #   language_model.model.resampler_model. -> resampler_model.
            "language_model.model.resampler_model.": "resampler_model.",
        },
        # Resampler weight mappings: indexed checkpoint sub-layers -> the
        # attribute names used by this implementation.
        orig_to_new_substr={
            "spatial_linear.0.": "spatial_linear1.",
            "spatial_linear.2.": "spatial_linear2.",
            "spatial_linear.3.": "spatial_norm.",
            "temporal_linear.0.": "temporal_linear1.",
            "temporal_linear.2.": "temporal_linear2.",
            "temporal_linear.3.": "temporal_norm.",
        },
    )
1319
1320

    @classmethod
1321
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
        if modality.startswith("image"):
            return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
        if modality.startswith("video"):
            return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the language model and the resampler.

        Args:
            vllm_config: Source of the HF config, the quantization config and
                the multimodal config.
            prefix: Name prefix applied to each sub-module.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        mm_config = model_config.multimodal_config

        self.config = hf_config
        self.multimodal_config = mm_config

        # Only a present multimodal config can override the encoder backend.
        if mm_config is None:
            attn_backend_override = None
        else:
            attn_backend_override = mm_config.mm_encoder_attn_backend

        self.vision_model = Ernie4_5_VisionTransformer(
            hf_config.vision_config,
            norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
            quant_config=vllm_config.quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
            attn_backend_override=attn_backend_override,
        )

        self.language_model = Ernie4_5_VLMoeForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.resampler_model = VariableResolutionResamplerModel(
            self.config.pixel_hidden_size,
            self.config.hidden_size,
            self.config.spatial_conv_size,
            self.config.temporal_conv_size,
            config=self.config,
            prefix=maybe_prefix(prefix, "resampler_model"),
        )

        # Set by _set_visual_token_mask once input ids are seen.
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1369
1370
1371
1372

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386

    def _vision_forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        if grid_thw is not None:
            grid_thw = grid_thw[grid_thw > 0]
            if grid_thw.numel() % 3 != 0:
                raise ValueError(
                    f"grid_thw has {grid_thw.numel()} elements after filtering,"
1387
1388
                    "which is not divisible by 3."
                )
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
            grid_thw = grid_thw.reshape(-1, 3)
            # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]]
            grid_thw = F.pad(
                torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0),
                [1, 0, 0, 0],
                value=1,
            )
        image_features = self.vision_model(pixel_values, grid_thw)
        return image_features

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        if getattr(self.config, "im_patch_id", None) is not None:
1401
1402
1403
            self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape(
                -1, 1
            )
1404
1405
1406
        else:
            self.visual_token_mask = None

1407
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor,
        video_grid_thw: list[list[int]] | torch.Tensor,
        context_len: int = 0,
        seq_len: int | None = None,
        second_per_grid_ts: list[float] | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for Ernie VL.

        Tokens are split into runs of text, image patches and video patches
        (an image-patch token seen between the video start/end tokens counts
        as video). Text runs get identical, consecutive positions on all three
        axes; image/video runs get (t, h, w) grid positions. Every run starts
        one past the largest position of the previous chunk.

        Args:
            input_tokens: Prompt token ids.
            hf_config: Provides ``im_patch_id``, ``video_start_token_id``,
                ``video_end_token_id``, ``spatial_conv_size`` and
                ``temporal_conv_size``.
            image_grid_thw: One ``(t, h, w)`` row per image, in prompt order.
            video_grid_thw: One ``(t, h, w)`` row per video, in prompt order.
            context_len: Start of the returned position slice.
            seq_len: End of the returned position slice (``None`` = to the end).
            second_per_grid_ts: Unused by this model.
            audio_feature_lengths: Unused by this model.
            use_audio_in_video: Unused by this model.

        Returns:
            ``(positions, delta)`` where ``positions`` has shape ``(3, n)`` and
            ``delta`` is ``positions.max() + 1 - len(input_tokens)``.
        """
        image_token_id = hf_config.im_patch_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_conv_size = hf_config.spatial_conv_size
        temporal_conv_size = hf_config.temporal_conv_size
        llm_pos_ids_list: list[torch.Tensor] = []

        def grid_positions(
            t_values: torch.Tensor, grid_h: int, grid_w: int
        ) -> torch.Tensor:
            # (3, len(t_values) * grid_h * grid_w) positions, row-major in (t, h, w).
            num_t = t_values.numel()
            t_index = t_values.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
            h_index = (
                torch.arange(grid_h).view(1, -1, 1).expand(num_t, -1, grid_w).flatten()
            )
            w_index = (
                torch.arange(grid_w).view(1, 1, -1).expand(num_t, grid_h, -1).flatten()
            )
            return torch.stack([t_index, h_index, w_index])

        if image_grid_thw is None and video_grid_thw is None:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
        else:
            # Work on plain Python ints for both modalities.
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()
            if isinstance(video_grid_thw, torch.Tensor):
                video_grid_thw = video_grid_thw.tolist()

            # Classify every token as text / image / video.
            input_token_type: list[str] = []
            inside_video = False
            for token in input_tokens:
                if token == video_start_token_id:
                    inside_video = True
                elif token == video_end_token_id:
                    inside_video = False

                if token != image_token_id:
                    input_token_type.append("text")
                elif inside_video:
                    input_token_type.append("video")
                else:
                    input_token_type.append("image")

            # Images and videos are indexed independently: each grid list only
            # holds the items of its own modality.
            image_idx = 0
            video_idx = 0
            for modality_type, run in itertools.groupby(input_token_type):
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0

                if modality_type == "image":
                    t, h, w = image_grid_thw[image_idx][:3]
                    llm_pos_ids_list.append(
                        grid_positions(
                            torch.arange(t),
                            h // spatial_conv_size,
                            w // spatial_conv_size,
                        )
                        + st_idx
                    )
                    image_idx += 1

                elif modality_type == "video":
                    t, h, w = video_grid_thw[video_idx][:3]
                    llm_grid_h = h // spatial_conv_size
                    llm_grid_w = w // spatial_conv_size
                    # One chunk per temporal step; all steps of a video share
                    # the same start offset.
                    for t_idx in range(t // temporal_conv_size):
                        llm_pos_ids_list.append(
                            grid_positions(
                                torch.tensor([t_idx]), llm_grid_h, llm_grid_w
                            )
                            + st_idx
                        )
                    video_idx += 1

                else:
                    text_len = sum(1 for _ in run)
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:seq_len]
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return llm_positions, mrope_position_delta

1551
1552
1553
1554
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        language_model = self.language_model
        return language_model

    def _parse_and_validate_image_input(
1555
        self, **kwargs: object
1556
    ) -> Ernie4_5_VLImageInputs | None:
1557
1558
1559
1560
1561
1562
1563
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None:
            return None

        if pixel_values is not None:
1564
1565
1566
1567
1568
            return Ernie4_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
1569
1570

    def _parse_and_validate_video_input(
1571
        self, **kwargs: object
1572
    ) -> Ernie4_5_VLVideoInputs | None:
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None:
            return None

        if pixel_values_videos is not None:
            return Ernie4_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
1587
1588
        self, image_input: Ernie4_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
1589
1590
1591
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

1592
1593
1594
1595
        pixel_values = image_input["pixel_values"].type(self.vision_model.dtype)
        image_features = self._vision_forward(
            pixel_values=pixel_values, grid_thw=grid_thw
        )
1596
1597
1598
1599
1600
1601
1602
1603
        image_embeds = self.resampler_model(image_features, grid_thw)

        merge_size = self.vision_model.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size

        return image_embeds.split(sizes.tolist())

    def _process_video_input(
1604
1605
        self, video_input: Ernie4_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
1606
1607
1608
1609
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        pixel_values_videos = video_input["pixel_values_videos"].type(
1610
1611
1612
1613
1614
            self.vision_model.dtype
        )
        video_features = self._vision_forward(
            pixel_values=pixel_values_videos, grid_thw=grid_thw
        )
1615
1616
1617
        video_embeds = self.resampler_model(video_features, grid_thw)

        merge_size = self.vision_model.spatial_merge_size
1618
1619
1620
1621
1622
        sizes = (
            (grid_thw.prod(-1) // self.config.temporal_conv_size)
            // merge_size
            // merge_size
        )
1623
1624
1625
1626
1627
1628
1629
1630
1631

        return video_embeds.split(sizes.tolist())

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
            if (
                input_key in ("pixel_values", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
1642
1643
1644
1645

        return modalities

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> MultiModalEmbeddings | None:
        """Return one embedding tensor per multimodal item, or None.

        ``None`` is returned when ``kwargs`` holds no image or video input.
        Otherwise the result is a tuple with each tensor corresponding to a
        multimodal data item (image or video).
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        collected: list[torch.Tensor] = []

        # NOTE: It is important to iterate over the dictionary in its own
        # order to preserve the order of the modalities.
        for modality, parsed_input in modalities.items():
            if modality == "images":
                collected.extend(self._process_image_input(parsed_input))
            elif modality == "videos":
                collected.extend(self._process_video_input(parsed_input))

        return tuple(collected)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, delegating to the parent implementation.

        When a non-empty ``multimodal_embeddings`` is supplied, the visual
        token mask is first refreshed from ``input_ids`` through
        ``_set_visual_token_mask``. The multimodal arguments are forwarded
        to the parent only when both ``multimodal_embeddings`` and
        ``is_multimodal`` are given; otherwise the parent is called with
        ``input_ids`` alone.
        """
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # The explicit ``is not None`` checks let the type checker pick the
        # matching overload of the parent method in each branch.
        if multimodal_embeddings is not None and is_multimodal is not None:
            return super().get_input_embeddings(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
                handle_oov_mm_token=handle_oov_mm_token,
            )

        return super().get_input_embeddings(input_ids)
1691
1692
1693
1694
1695

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the language model, attaching the pending visual token mask.

        If ``self.visual_token_mask`` is set, it is passed to the language
        model as ``visual_token_mask`` and then cleared, so a mask is used
        by at most one forward pass. When its first dimension differs from
        the number of tokens in this pass, the mask is right-padded with
        zeros (``False`` for a boolean mask) up to that length.

        Args:
            input_ids: Token ids for this pass.
            positions: Position ids, forwarded unchanged.
            intermediate_tensors: Forwarded unchanged to the language model.
            inputs_embeds: Optional precomputed input embeddings. When
                given, its first dimension defines the token count used
                to pad the mask; otherwise ``input_ids`` is used.
            **kwargs: Extra keyword arguments forwarded to the language
                model.

        Returns:
            Whatever ``self.language_model.model`` returns (the hidden
            states).
        """
        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        mask = self.visual_token_mask
        if mask is not None:
            # ``inputs_embeds`` is optional, so do not dereference it
            # blindly. The mask is built from ``input_ids`` in
            # ``get_input_embeddings``, so ``input_ids`` is used as the
            # fallback source for the token count.
            token_source = inputs_embeds if inputs_embeds is not None else input_ids
            padding_len = token_source.shape[0] - mask.shape[0]
            if padding_len != 0:
                # right pad False
                pad = torch.zeros(
                    (padding_len, mask.shape[1]),
                    dtype=mask.dtype,
                    device=mask.device,
                )
                mask = torch.cat([mask, pad], dim=0)

            forward_kwargs["visual_token_mask"] = mask
            # The mask is single-use: clear it so a later pass without
            # multimodal inputs does not pick up a stale mask.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(
            **forward_kwargs,
            **kwargs,
        )

        return hidden_states

1728
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` into this model through ``AutoWeightsLoader``.

        The checkpoint names are translated with ``self.hf_to_vllm_mapper``.
        Returns the set of names reported by the loader.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )