glm4_1v.py 65.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The ZhipuAI Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""

zhuwenwen's avatar
zhuwenwen committed
29
import os
30
import itertools
31
import math
32
from collections.abc import Callable, Iterable, Mapping, Sequence
33
from functools import partial
34
from typing import Annotated, Any, Literal, TypeAlias
35
36
37
38
39
40

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
41
from transformers import BatchFeature, Glm4vProcessor
Yuxuan Zhang's avatar
Yuxuan Zhang committed
42
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
43
from transformers.models.glm4v.image_processing_glm4v import (
44
45
46
47
    Glm4vImageProcessor,
    smart_resize,
)
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
48
49
from transformers.video_utils import VideoMetadata

50
from vllm.attention.backends.registry import AttentionBackendEnum
51
52
53
54
from vllm.attention.layers.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.config import MultiModalConfig, VllmConfig
55
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
56
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
57
58
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
59
from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
60
from vllm.model_executor.layers.layernorm import RMSNorm
61
62
63
64
65
66
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
67
from vllm.model_executor.layers.quantization import QuantizationConfig
68
from vllm.model_executor.layers.rotary_embedding import get_rope
69
70
71
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
72
73
74
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
75
76
from vllm.multimodal.inputs import (
    MultiModalDataDict,
77
    MultiModalFeatureSpec,
78
79
80
81
82
83
84
85
86
87
88
89
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
90
91
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
92
from vllm.utils.tensor_schema import TensorSchema, TensorShape
93
94

from ..layers.activation import SiluAndMul
95
96
97
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
98
    SupportsMRoPE,
99
100
101
    SupportsMultiModal,
    SupportsPP,
)
102
from .qwen2_vl import _create_qwen2vl_field_factory
103
104
105
106
107
108
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
109
110
111
112
from .vision import (
    get_vit_attn_backend,
    run_dp_sharded_mrope_vision_model,
)
113
114
115
116
117
118
119
120
121

logger = init_logger(__name__)

# Cap on the number of frames per video assumed for the profile run.
_MAX_FRAMES_PER_VIDEO: int = 600

# === Vision Inputs === #


122
class Glm4vImagePixelInputs(TensorSchema):
    """
    Image inputs given as flattened pixel patches.

    Dimensions:
        - np: Number of patches
        - cpp: Number of channels * patch_size * patch_size
        - ni: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """

    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")]
    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
135
136


137
class Glm4vImageEmbeddingInputs(TensorSchema):
    """
    Image inputs given as precomputed embeddings.

    Dimensions:
        - f: Number of image features (varies based on image resolution)
        - h: Hidden size (must match language model backbone)
        - n: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """

    type: Literal["image_embeds"] = "image_embeds"

    image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")]
    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]
150
151


152
# An image input is either raw pixel patches or precomputed embeddings.
Glm4vImageInputs: TypeAlias = (
    Glm4vImagePixelInputs | Glm4vImageEmbeddingInputs
)
153
154


155
class Glm4vVideoPixelInputs(TensorSchema):
    """
    Video inputs given as flattened pixel patches.

    Dimensions:
        - np: Number of patches
        - ctpp: Number of channels * temporal_patch_size *
            patch_size * patch_size
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
170
171


172
class Glm4vVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs given as precomputed embeddings.

    Dimensions:
        - p: Number of video patches across all frames
        - h: Hidden size (must match language model backbone)
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """

    type: Literal["video_embeds"] = "video_embeds"

    video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
186
187


188
# A video input is either raw pixel patches or precomputed embeddings.
Glm4vVideoInputs: TypeAlias = (
    Glm4vVideoPixelInputs | Glm4vVideoEmbeddingInputs
)
189

190
# ==== Vision Encoder ==== #
191
192
193
194
195
196
197
198


class Glm4vVisionMLP(nn.Module):
    """Gated MLP used inside each vision block.

    A merged gate/up projection (two outputs of ``hidden_features`` each) is
    followed by ``SiluAndMul`` and a down projection back to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # In the "data" encoder-TP mode the linear layers are built with
        # tensor parallelism disabled.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group.

    Every rank's tensor is split along its last dimension into chunks of
    ``hidden_size // tp_size``. The result is chunk 0 from every rank (in rank
    order), then chunk 1 from every rank, and so on, concatenated along the
    last dimension.

    Args:
        local_tensor: This rank's tensor.
        hidden_size: Full last-dimension size; ``hidden_size // tp_size`` is
            the chunk width.
        tp_size: Number of ranks in the tensor-parallel group.

    Returns:
        The interleaved concatenation of all ranks' chunks.
    """
    import torch.distributed as dist

    # all_gather writes every output buffer completely, so there is no need
    # to zero-fill them first.
    gathered_tensors = [torch.empty_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )

    chunk_size = hidden_size // tp_size
    per_rank_chunks = [
        torch.split(tensor, chunk_size, -1) for tensor in gathered_tensors
    ]
    # zip(*...) yields (rank0_chunk_i, rank1_chunk_i, ...) for each i.
    ordered_tensors = [
        chunk for same_index in zip(*per_rank_chunks) for chunk in same_index
    ]
    return torch.cat(ordered_tensors, dim=-1)


class Glm4vVisionAttention(nn.Module):
    """Self-attention layer of a GLM-4V vision block.

    Q/K/V come from a fused ``QKVParallelLinear``. Rotary embeddings are
    applied to Q and K when cos/sin are given. ``MMEncoderAttention``
    attends over the variable-length sequences described by ``cu_seqlens``,
    and ``RowParallelLinear`` projects the result back to ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # In the "data" encoder-TP mode the layers are built with
        # disable_tp=True and this module behaves as tp_size == 1.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        # Per attention head and per partition values.
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = (
            0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
        )
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=False,
            quant_config=quant_config,
            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            bias=False,
            disable_tp=use_data_parallel,
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            multimodal_config=multimodal_config,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split fused QKV activations into per-head Q, K and V.

        Args:
            qkv: ``[s, b, 3 * head * head_dim]``. A 2-D
                ``[s, 3 * head * head_dim]`` input is also accepted, in which
                case a batch dimension of size 1 is inserted at dim 1.

        Returns:
            ``(q, k, v)``, each viewed as ``[s, b, head, head_dim]``.
        """
        # Add a batch dimension at dim=1 for 2-D input.
        if qkv.dim() == 2:
            qkv = qkv.unsqueeze(1)
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        # Move the batch dimension first: [b, s, head, head_dim].
        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
            # Rotate Q and K in a single call: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = self.apply_rotary_emb(
                qk_concat,
                rotary_pos_emb_cos,
                rotary_pos_emb_sin,
            )
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        # [b, s, head, head_dim] -> [s, b, head * head_dim]
        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Glm4vVisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then gated MLP.

    ``norm1`` is applied before attention. ``norm2`` is then called with the
    attention output as ``residual`` and feeds the MLP, whose output is added
    back to the residual returned by ``norm2``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # NOTE(review): forward() calls norm2(x, residual=...), which
        # nn.LayerNorm does not accept; the vision transformer passes an
        # RMSNorm factory, so this default is only safe if never used.
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Glm4vVisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            mlp_hidden_dim,
            bias=False,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
        # Tensor (or None) as produced by compute_attn_mask_seqlen.
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        x_attn = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )

        # Make the attention output 3-D before using it as the residual.
        if x_attn.dim() == 2:
            x_attn = x_attn.unsqueeze(1)
        elif x_attn.dim() == 1:
            x_attn = x_attn.unsqueeze(1).unsqueeze(2)

        assert x_attn.dim() == 3, f"x_attn must be 3D, got {x_attn.shape}"

        # Presumably RMSNorm's fused add+norm: returns (norm(x + x_attn),
        # x + x_attn) — TODO confirm against RMSNorm.forward.
        x_fused_norm, residual = self.norm2(x, residual=x_attn)

        # The MLP is fed 2-D input; a singleton batch dim is squeezed out
        # here and restored on the output.
        if x_fused_norm.dim() == 3 and x_fused_norm.shape[1] == 1:
            mlp_in = x_fused_norm.squeeze(1)
            restore_3d = True
        elif x_fused_norm.dim() == 2:
            mlp_in = x_fused_norm
            restore_3d = False
        else:
            raise RuntimeError(
                f"Unexpected x_fused_norm shape {x_fused_norm.shape}, "
                "expect (N,D) or (N,1,D)"
            )

        out = self.mlp(mlp_in)
        if restore_3d:
            out = out.unsqueeze(1)
        assert out.shape == residual.shape, (
            f"residual {residual.shape} vs mlp_out {out.shape} mismatch"
        )
        x = residual + out

        return x


class Glm4vVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a strided 3-D convolution.

    The kernel and stride both equal ``(temporal_patch_size, patch_size,
    patch_size)``. Each input row is therefore mapped to exactly one
    ``hidden_size`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = Conv3dLayer(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` of shape ``[L, C * t * p * p]`` to ``[L, hidden_size]``."""
        # The 2-tuple unpacking also enforces a 2-D input.
        L, _ = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
        # Opt-in channels-last 3-D layout, checked on every call, when the
        # MIOpen NDHWC suggestion is enabled via environment variable.
        if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
            x = x.to(memory_format=torch.channels_last_3d)
        x = self.proj(x).view(L, self.hidden_size)
        return x


class Glm4vPatchMerger(nn.Module):
    """Final projection applied to the downsampled vision features.

    The steps are ``proj`` → LayerNorm → GELU → gate/up projection (two
    outputs of ``context_dim``) → ``SiluAndMul`` → down projection back to
    ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # In the "data" encoder-TP mode the linear layers are built with
        # tensor parallelism disabled.
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.hidden_size = d_model
        self.proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=bias,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[context_dim] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor):
        x, _ = self.proj(x)
        x = self.extra_activation_func(self.post_projection_norm(x))
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Glm4vVisionEmbeddings(nn.Module):
    """Learned absolute position embeddings, resampled to each input grid.

    The table has ``(image_size // patch_size) ** 2`` entries. At runtime it
    is viewed as a square 2-D grid and bicubically sampled at every patch's
    (h, w) coordinate, normalised by that patch's image/frame grid size.
    """

    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    @staticmethod
    def _segment_shapes(num_segments: int, image_shapes: torch.Tensor) -> torch.Tensor:
        """Pick one ``(t, h, w)`` row of ``image_shapes`` per length segment.

        ``image_shapes`` has one row per image/video, whereas the lengths
        passed by the vision transformer are per frame: its ``cu_seqlens`` is
        built with ``repeat_interleave`` over ``grid_thw[:, 0]``.

        - If there are at least as many rows as segments, the first
          ``num_segments`` rows are used, matching the original behaviour.
        - If the segment count equals the total number of frames, each row is
          repeated ``t`` times so that every frame gets its own item's grid.
        - Otherwise the rows are cycled, as the original did, e.g. for
          data-parallel shards that do not hold every shape.
        """
        num_shapes = image_shapes.shape[0]
        if num_segments <= num_shapes:
            return image_shapes[:num_segments]
        if num_shapes == 0:
            raise ValueError(
                f"Got {num_segments} position segments but no image shapes"
            )
        frames_per_item = image_shapes[:, 0]
        if int(frames_per_item.sum()) == num_segments:
            return image_shapes.repeat_interleave(frames_per_item, dim=0)
        shape_idx = (
            torch.arange(num_segments, device=image_shapes.device) % num_shapes
        )
        return image_shapes[shape_idx]

    def forward(
        self, embeddings, lengths, image_shapes, h_coords, w_coords
    ) -> torch.Tensor:
        """Add resampled 2-D position embeddings to ``embeddings``.

        Args:
            embeddings: Patch embeddings, one row per patch.
            lengths: Number of patches in each segment (list or 1-D tensor).
            image_shapes: ``(t, h, w)`` grid rows (tensor or nested list).
            h_coords: Row coordinate of every patch.
            w_coords: Column coordinate of every patch.

        Returns:
            ``embeddings`` plus the interpolated position embeddings.
        """
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        # Handle empty sequence case
        if total_seq == 0:
            adapted_pos_embed = torch.empty(
                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
            )
        else:
            lengths = torch.as_tensor(lengths, device=device, dtype=torch.long)
            image_shapes = torch.as_tensor(
                image_shapes, device=device, dtype=torch.long
            )

            # View the learned table as a square grid: [1, hidden, S, S].
            orig_size = int(pos_embed_weight.shape[0] ** 0.5)
            pos_embed_2d = (
                pos_embed_weight.view(orig_size, orig_size, hidden_size)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device=device, dtype=torch.float32)
            )

            # Grid height/width of the image/frame each patch belongs to.
            seg_shapes = self._segment_shapes(lengths.numel(), image_shapes)
            target_h = seg_shapes[:, 1].repeat_interleave(lengths).to(torch.float32)
            target_w = seg_shapes[:, 2].repeat_interleave(lengths).to(torch.float32)

            # Normalize patch-centre coordinates to [-1, 1] for grid_sample
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # Sampling grid [1, total_seq, 1, 2] in (x=w, y=h) order.
            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)

            # Bicubic interpolation -> [1, hidden, total_seq, 1]
            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d,
                grid,
                mode="bicubic",
                align_corners=False,
                padding_mode="border",
            )

            # [1, hidden, total_seq, 1] -> [total_seq, hidden], original dtype
            adapted_pos_embed_fp32 = (
                interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
            )
            adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(
                embeddings.device
            )

        # Add adapted position encoding to embeddings
        embeddings = embeddings + adapted_pos_embed
        return embeddings


class Glm4vVisionTransformer(nn.Module):
    """GLM-4V vision encoder.

    Pipeline: patch embedding → post-conv RMSNorm → interpolated absolute
    position embeddings → ``depth`` vision blocks with 2-D rotary embeddings
    → RMSNorm → spatial-merge downsampling conv → patch merger.
    """

    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert multimodal_config is not None, "multimodal_config must be provided"

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(RMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )
        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.out_hidden_size,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            bias=False,
            prefix=f"{prefix}.merger",
        )
        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = RMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )
        self.downsample = Conv2dLayer(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = RMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=multimodal_config.mm_encoder_attn_backend,
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build per-patch rotary cos/sin tables and (h, w) position ids.

        Positions are ordered so that each ``spatial_merge_size`` x
        ``spatial_merge_size`` window is contiguous, and they are repeated
        ``t`` times per item.

        Returns:
            ``(cos, sin, pos_ids)``. ``pos_ids`` has one ``(h, w)`` row per
            patch; cos/sin are gathered at those ids and flattened per patch.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)
        return cos_combined, sin_combined, pos_ids

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor | None:
        """Return the longest sequence length for flash-attention backends, else None."""
        max_seqlen = None
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches into merged vision embeddings.

        Args:
            x: Flattened pixel patches, one row per patch.
            grid_thw: One ``(t, h, w)`` row per image/video.
        """
        if isinstance(grid_thw, list):
            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)

        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        # compute position embedding
        rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
            grid_thw
        )
        # compute cu_seqlens: one sequence of h * w patches per frame
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

        # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        x = self.embeddings(
            x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
        )

        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # adapter
        x = self.post_layernorm(x)

        # Group each merge window into a [merge, merge] spatial tile, then
        # downsample every tile to a single out_hidden_size vector.
        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, routing q/k/v and gate/up shards into fused params.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Glm4vProcessingInfo(BaseProcessingInfo):
    """Processing info for GLM-4.1V / GLM-4.6V.

    Provides vision-token accounting (used to size dummy inputs), the
    per-frame timestamps rendered into video prompts, and the token-level
    video placeholder used during prompt replacement.
    """

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No limit on images; at most one video per prompt.
        return {"image": None, "video": 1}

    def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 16,
        do_resize: bool = True,
        max_image_pixels: int = 28 * 28 * 2 * 30000,
    ) -> tuple[ImageSize, int]:
        """Return the preprocessed frame size and the number of vision tokens.

        Args:
            image_width: Original frame width in pixels.
            image_height: Original frame height in pixels.
            num_frames: Number of frames. Resizing uses at least
                ``temporal_patch_size`` frames.
            do_resize: Whether to apply ``smart_resize``.
            max_image_pixels: Pixel budget forwarded to ``smart_resize``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``, where the token count
            is ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size
        if do_resize:
            resized_height, resized_width = smart_resize(
                num_frames=max(num_frames, temporal_patch_size),
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                max_pixels=max_image_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded (by repeating the last frame) up to the next
        # multiple of `temporal_patch_size`. `(-n) % tps` is the number of
        # frames needed to round up; it equals `n % 2` when tps == 2.
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        # An oversized input is shrunk by `smart_resize` to the largest size
        # allowed by the default pixel budget.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999, image_height=9999999
        )
        return max_image_size

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            max_image_pixels=28 * 28 * 2 * 6144,
        )
        return num_image_tokens

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            max_image_pixels=28 * 28 * 2 * 30000,
        )
        return num_video_tokens

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose video fits in ``max_tokens``.

        Frames are added one at a time until the token count would exceed the
        budget (or would be zero).
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )
            if next_max_tokens > max_tokens or next_max_tokens == 0:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame count to use for dummy inputs.

        The token budget left after the maximum-size images is split evenly
        across videos. The result is capped at ``_MAX_FRAMES_PER_VIDEO`` and
        is at least 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
        )

        return max(max_frames_per_video, 1)

    @staticmethod
    def _frame_indices_to_timestamps(
        frame_indices: Iterable[int], video_fps: float
    ) -> list[int]:
        """Convert sampled frame indices to whole-second timestamps.

        Duplicate indices are dropped (first occurrence kept). An odd count
        is padded by repeating the last index. The timestamp of every other
        frame (positions 0, 2, 4, ...) is returned, computed as
        ``int(idx / video_fps)``.
        """
        uniq = list(dict.fromkeys(frame_indices))
        if len(uniq) % 2 == 1:
            uniq.append(uniq[-1])
        return [int(idx / video_fps) for idx in uniq[::2]]

    def _get_video_second_idx_glm4v(
        self, metadata: dict[str, Any], total_frames: int
    ) -> list[int]:
        """GLM-4.1V: return the timestamps rendered after each frame group.

        If ``do_sample_frames`` is false, ``metadata["frames_indices"]`` is
        used as-is. Otherwise frames are sampled at ``video_processor.fps``;
        clips longer than ``video_processor.max_duration`` are instead sampled
        at ``max_duration * fps`` evenly spaced timestamps.
        """
        video_processor = self.get_video_processor()

        video_fps = metadata.get("fps", video_processor.fps)
        meta_frames = metadata.get("total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)

        # Default to sampling when the flag is absent, matching the GLM-4.6V
        # path below and `Glm4vMultiModalProcessor._call_hf_processor`.
        do_sample_frames = metadata.get("do_sample_frames", True)
        if not do_sample_frames:
            frame_indices = metadata["frames_indices"]
        elif duration <= video_processor.max_duration:
            n = int(math.floor(duration * video_processor.fps))
            frame_indices = [
                min(
                    max_frame_idx,
                    int(math.ceil(i * video_fps / video_processor.fps)),
                )
                for i in range(n)
            ]
        else:
            num_samples = int(video_processor.max_duration * video_processor.fps)
            if num_samples >= meta_frames:
                frame_indices = list(range(meta_frames))
            else:
                target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
                frame_indices = [
                    min(max_frame_idx, int(math.ceil(t * video_fps)))
                    for t in target_seconds
                ]

        return self._frame_indices_to_timestamps(frame_indices, video_fps)

    def _get_video_second_idx_glm46v(
        self, metadata: dict[str, Any], total_frames: int
    ) -> list[int]:
        """GLM-4.6V: return the timestamps rendered after each frame group.

        If ``do_sample_frames`` is false, ``metadata["frames_indices"]`` is
        used as-is. Otherwise a duration-dependent target fps is chosen
        (3 fps up to 30 s, 1 fps up to 300 s, else 0.5 fps; duration capped at
        2400 s). Frames are then extracted, capped at 640, and resampled
        uniformly to the target count.
        """
        video_processor = self.get_video_processor()

        video_fps = metadata["fps"]
        meta_frames = metadata.get("total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)

        do_sample_frames = metadata.get("do_sample_frames", True)
        if not do_sample_frames:
            frame_indices = metadata["frames_indices"]
        else:
            DYNAMIC_FPS_THRES = {30: 3, 300: 1, 2400: 0.5}
            MAX_FRAME_COUNT_DYNAMIC = 640
            MAX_DURATION = 2400

            effective_duration = min(duration, MAX_DURATION)
            if effective_duration <= 30:
                target_fps = DYNAMIC_FPS_THRES[30]
            elif effective_duration <= 300:
                target_fps = DYNAMIC_FPS_THRES[300]
            else:
                target_fps = DYNAMIC_FPS_THRES[2400]

            temporal_patch_size = getattr(video_processor, "temporal_patch_size", 1)
            extract_t = int(effective_duration * target_fps * temporal_patch_size)
            extract_t = min(extract_t, MAX_FRAME_COUNT_DYNAMIC)

            duration_per_frame = 1 / video_fps
            timestamps = [i * duration_per_frame for i in range(meta_frames)]
            max_second = int(duration)

            if meta_frames < extract_t:
                frame_indices = np.linspace(
                    0, meta_frames - 1, extract_t, dtype=int
                ).tolist()
            else:
                # Greedily take the first frame at or after each target time.
                frame_indices = []
                current_second = 0.0
                inv_fps = 1 / (temporal_patch_size * target_fps)
                for frame_index in range(meta_frames):
                    if timestamps[frame_index] >= current_second:
                        current_second += inv_fps
                        frame_indices.append(frame_index)
                        if current_second >= max_second:
                            break

            if len(frame_indices) < extract_t:
                if len(frame_indices) == 0:
                    start, end = 0, max(meta_frames - 1, 0)
                else:
                    start, end = frame_indices[0], frame_indices[-1]
                frame_indices = np.linspace(start, end, extract_t, dtype=int).tolist()
            elif len(frame_indices) > extract_t:
                frame_indices = np.linspace(
                    0, meta_frames - 1, extract_t, dtype=int
                ).tolist()

        return self._frame_indices_to_timestamps(frame_indices, video_fps)

    def _construct_video_placeholder(
        self,
        video_array: np.ndarray,
        metadata: dict[str, Any],
        grid_thw: torch.Tensor,
    ) -> list[int]:
        """Build the token ids that replace one video placeholder.

        Layout: the video-start token, then for each timestamp: image-start,
        ``H * W // merge_size**2`` video tokens, image-end and the tokenized
        timestamp; finally the video-end token.

        Args:
            video_array: The video frames. Only ``len(video_array)`` is used,
                as the frame-count fallback.
            metadata: Video metadata (fps, frame indices, sampling flag, ...).
            grid_thw: ``(T, H, W)`` patch grid of this video.
        """
        hf_processor = self.get_hf_processor()
        tokenizer = self.get_tokenizer()
        image_processor = hf_processor.image_processor

        hf_config = self.get_hf_config()
        boi_token_id = hf_config.image_start_token_id
        eoi_token_id = hf_config.image_end_token_id
        bov_token_id = hf_config.video_start_token_id
        eov_token_id = hf_config.video_end_token_id
        merge_length = image_processor.merge_size**2

        assert isinstance(grid_thw, torch.Tensor)

        # GLM-4.1V (Glm4vProcessor) and GLM-4.6V sample frames differently
        # and render timestamps in different textual formats.
        if isinstance(hf_processor, Glm4vProcessor):
            timestamps = self._get_video_second_idx_glm4v(metadata, len(video_array))
            timestamp_format = "{}"
        else:
            timestamps = self._get_video_second_idx_glm46v(metadata, len(video_array))
            timestamp_format = "{:.1f} seconds"

        frames_idx_token = [
            tokenizer.encode(timestamp_format.format(t), add_special_tokens=False)
            for t in timestamps
        ]
        _, H, W = grid_thw
        num_tokens_per_frame = int(H * W) // merge_length

        placeholder = [bov_token_id]
        for frame_idx in frames_idx_token:
            placeholder.append(boi_token_id)
            placeholder.extend([hf_processor.video_token_id] * num_tokens_per_frame)
            placeholder.append(eoi_token_id)
            placeholder.extend(frame_idx)
        placeholder.append(eov_token_id)

        return placeholder

class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
    """Builds dummy prompts and multimodal data for GLM-4.1V / GLM-4.6V."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt containing one placeholder per dummy item."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_config = self.info.get_hf_config()
        hf_processor = self.info.get_hf_processor()
        tokenizer = self.info.get_tokenizer()

        image_token: str = hf_processor.image_token
        # Decode start/video/end ids into text. This presumably yields the
        # "<|begin_of_video|><|video|><|end_of_video|>" target used by the
        # multimodal processor; verify against the tokenizer.
        video_token_ids = [
            hf_config.video_start_token_id,
            hf_processor.video_token_id,
            hf_config.video_end_token_id,
        ]
        video_token = tokenizer.decode(video_token_ids)

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return max-size dummy images and videos, honoring ``mm_options``."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[VideoItem]:
        """Return ``num_videos`` white ``(video, metadata)`` dummy items.

        Overrides may only shrink ``num_frames`` / ``width`` / ``height``.
        Larger values are logged and ignored. At least two frames are always
        produced.
        """
        if overrides:
            if overrides.num_frames:
                if overrides.num_frames > num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        overrides.num_frames,
                        num_frames,
                    )
                num_frames = min(num_frames, overrides.num_frames)
            if overrides.width:
                if overrides.width > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        overrides.width,
                        width,
                    )
                width = min(width, overrides.width)
            if overrides.height:
                if overrides.height > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        overrides.height,
                        height,
                    )
                height = min(height, overrides.height)

        num_frames = max(num_frames, 2)  # GLM 4.6V requires 2 frames

        # Decoded video frames use the usual (T, H, W, C) layout. Height must
        # precede width, or the frames are transposed whenever the two differ.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                # Use every frame as-is; no resampling for dummy inputs.
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))

        return video_items


class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
    """Multimodal processor for GLM-4.1V / GLM-4.6V.

    Each video is run through the HF processor on its own, so its expanded
    placeholder can be spliced back into the prompt before the images and
    text are processed.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Video items carry metadata (fps, frame indices, ...), which is
        # needed to compute per-frame timestamps.
        return MultiModalDataParser(video_needs_metadata=True)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling videos one at a time.

        GLM-4.1V uses ``image_token_id`` as the video placeholder, so each
        video's token ids are rewritten to ``video_token_id``. The result is
        decoded and substituted for the first remaining video placeholder in
        ``prompt``. The per-video pixel values and grids are concatenated and
        merged into the final output.
        """
        video_placeholder_text = "<|begin_of_video|><|video|><|end_of_video|>"
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        if (
            "videos" in mm_data
            and isinstance(mm_data["videos"], list)
            and len(mm_data["videos"]) > 0
        ):
            video_grid_thw_lst = []
            pixel_values_videos_lst = []
            for video_array, metadata in mm_data.pop("videos", []):
                # Copy so the caller's mm_kwargs is not modified in place.
                video_mm_kwargs = {
                    **mm_kwargs,
                    "do_sample_frames": metadata.get("do_sample_frames", True),
                }
                # `do_sample_frames` is passed as a kwarg, not as metadata.
                hf_metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )
                video_mm_data = {
                    "videos": [[video_array]],
                    "video_metadata": [[hf_metadata]],
                }

                single_outputs = super()._call_hf_processor(
                    prompt=video_placeholder_text,
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                input_ids = single_outputs.pop("input_ids")
                input_ids[input_ids == processor.image_token_id] = (
                    processor.video_token_id
                )
                video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(video_placeholder_text, video_placeholder, 1)

                video_grid_thw_lst.append(single_outputs["video_grid_thw"])
                pixel_values_videos_lst.append(single_outputs["pixel_values_videos"])

            video_outputs = {
                "pixel_values_videos": torch.cat(pixel_values_videos_lst),
                "video_grid_thw": torch.cat(video_grid_thw_lst),
            }
        else:
            video_outputs = {}

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        return BatchFeature({**processed_outputs, **video_outputs})

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        spatial_merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        return _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return replacements expanding image and video placeholders.

        An image expands to ``prod(grid_thw) // merge_size**2`` image tokens.
        A video expands to the timestamped token sequence built by
        ``Glm4vProcessingInfo._construct_video_placeholder``. Only its video
        tokens are selected as embedding positions.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)

        merge_length = image_processor.merge_size**2

        def get_image_replacement_glm4v(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_glm4v(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            video, metadata = mm_items["video"][item_idx]
            placeholder = self.info._construct_video_placeholder(
                video, metadata, grid_thw
            )
            return PromptUpdateDetails.select_token_id(
                placeholder,
                embed_token_id=hf_processor.video_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_glm4v,
            ),
            PromptReplacement(
                modality="video",
                target="<|begin_of_video|><|video|><|end_of_video|>",
                replacement=get_video_replacement_glm4v,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
1421
class Glm4vForConditionalGeneration(
1422
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
1423
):
1424
1425
1426
1427
1428
1429
    # Fused module name -> checkpoint sub-module names packed into it
    # (presumably consumed by vLLM's LoRA / quantization packing logic).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_up_proj"],
    }

    # Rename HF checkpoint prefixes to this module's attribute layout
    # (`language_model.*` and `visual.*`), ensuring correct weight loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
        }
    )

    # The vision encoder may run data-parallel
    # (`mm_encoder_tp_mode == "data"`, checked in `__init__`).
    supports_encoder_tp_data = True
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for one multimodal item.

        Args:
            modality: Modality name, matched by prefix ("image..." / "video...").
            i: Item index. Unused, because every item shares the same
                placeholder.

        Raises:
            ValueError: For any modality other than image or video.
        """
        if modality.startswith("image"):
            return "<|begin_of_image|><|image|><|end_of_image|>"
        if modality.startswith("video"):
            return "<|begin_of_video|><|video|><|end_of_video|>"

        raise ValueError("Only image or video modality is supported")
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower and the registered text backbone.

        Args:
            vllm_config: Engine config. ``model_config.hf_config`` is the
                GLM-4V config; ``quant_config`` and ``multimodal_config`` are
                forwarded to the vision tower.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        # When True, the vision encoder is run via the DP-sharded helper in
        # `_process_image_input` / `_process_video_input`.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Glm4vVisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "visual"),
        )

        # Choose the text backbone from the top-level model type. `None`
        # defers the choice to `init_vllm_registered_model` (presumably
        # resolved from `text_config`).
        if config.model_type == "glm4v":
            architectures = ["Glm4ForCausalLM"]
        elif config.model_type == "glm4v_moe":
            architectures = ["Glm4MoeForCausalLM"]
        else:
            architectures = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Glm4vImageInputs | None:
        """Build the image input from ``kwargs``, or return None if absent.

        ``pixel_values`` takes precedence over ``image_embeds``. Both are
        paired with ``image_grid_thw``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Glm4vImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        # Only `image_embeds` can be set at this point.
        return Glm4vImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Glm4vVideoInputs | None:
        """Build the video input from ``kwargs``, or return None if absent.

        ``pixel_values_videos`` takes precedence over ``video_embeds``. Both
        are paired with ``video_grid_thw``.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Glm4vVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        # Only `video_embeds` can be set at this point.
        return Glm4vVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=video_embeds,
            video_grid_thw=video_grid_thw,
        )

    def _process_image_input(
        self, image_input: Glm4vImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images (or cast precomputed embeddings) and split per image.

        Each image contributes ``prod(grid_thw) // merge_size // merge_size``
        rows of the concatenated embedding tensor.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                # Returned as-is, without the per-image split below; presumably
                # the helper already yields per-image embeddings.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _process_video_input(
        self, video_input: Glm4vVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos (or cast precomputed embeddings) and split per video.

        Each video contributes ``prod(grid_thw) // merge_size // merge_size``
        rows of the concatenated embedding tensor.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            if self.use_data_parallel:
                # Returned as-is, without the per-video split below; presumably
                # the helper already yields per-video embeddings.
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw.tolist(),
                    rope_type="rope_3d",
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs keyed by modality.

        Modalities are inserted in the order their first input key appears in
        ``kwargs``, so embeddings are later emitted in prompt order.
        """
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        """Return the text backbone built in ``__init__``."""
        return self.language_model
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Return one embedding tensor per image/video item, or None.

        Items are ordered by modality, in the order the modalities first
        appear in ``kwargs``, and within a modality by item order.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return None

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                image_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += tuple(video_embeddings)
        return multimodal_embeddings
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-axis (temporal, height, width) M-RoPE position ids.

        Each token is classified as one of three types:

        * ``"image"``: an image placeholder token outside a video span.
        * ``"video"``: an image placeholder token between the video start
          and end tokens.
        * ``"text"``: anything else.

        Consecutive tokens of the same type form a run. Each run starts one
        past the largest position id emitted so far.

        * Text runs get the same increasing ids on all three axes.
        * Image runs get a (t, h, w) grid of ids from the image's
          ``image_grid_thw`` entry, with h and w divided by the spatial
          merge size.
        * Each video run is treated as one temporal group of the current
          video. Its (h, w) grid comes from that video's ``video_grid_thw``
          entry. After ``t`` runs (the video's temporal grid size), the
          next video entry is used.

        Args:
            input_tokens: Token ids of the prompt.
            mm_features: Multimodal features of the request. Only their
                ``image_grid_thw`` / ``video_grid_thw`` values are read.

        Returns:
            A tuple ``(positions, delta)``.

            * ``positions`` has 3 rows; the number of columns is the total
              length of all runs.
            * ``delta`` is ``positions.max() + 1 - len(input_tokens)``.
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {"image_grid_thw", "video_grid_thw"},
        )
        image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
        video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]

        hf_config = self.config
        image_token_id = hf_config.image_token_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        llm_pos_ids_list: list = []

        if image_grid_thw or video_grid_thw:
            # Image placeholder tokens inside a video span belong to a video
            # frame; outside of one they belong to a standalone image.
            input_token_type: list[str] = []
            video_check_flg = False
            for token in input_tokens:
                if token == video_start_token_id:
                    video_check_flg = True
                elif token == video_end_token_id:
                    video_check_flg = False

                if (token == image_token_id) and (video_check_flg is False):
                    input_token_type.append("image")
                elif (token == image_token_id) and (video_check_flg is True):
                    input_token_type.append("video")
                else:
                    input_token_type.append("text")

            # Collapse the per-token types into (type, start, end) runs.
            input_type_group: list[tuple[str, int, int]] = []
            for key, group_iter in itertools.groupby(
                enumerate(input_token_type), lambda x: x[1]
            ):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((key, start_index, end_index))

            # Images and videos are indexed independently.
            # - image_grid_thw holds image items only.
            # - video_grid_thw holds one entry per video.
            # - video_group_index counts the frame runs already consumed
            #   from the current video.
            image_index = 0
            video_index = 0
            video_group_index = 0
            video_frame_num = 1
            for modality_type, start_idx, end_idx in input_type_group:
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                )
                if modality_type == "image":
                    t, h, w = image_grid_thw[image_index]
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t,
                        h // spatial_merge_size,
                        w // spatial_merge_size,
                    )

                    t_index = (
                        torch.arange(llm_grid_t)
                        .view(-1, 1)
                        .expand(-1, llm_grid_h * llm_grid_w)
                        .flatten()
                    )
                    h_index = (
                        torch.arange(llm_grid_h)
                        .view(1, -1, 1)
                        .expand(llm_grid_t, -1, llm_grid_w)
                        .flatten()
                    )
                    w_index = (
                        torch.arange(llm_grid_w)
                        .view(1, 1, -1)
                        .expand(llm_grid_t, llm_grid_h, -1)
                        .flatten()
                    )
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx
                    )
                    image_index += 1

                elif modality_type == "video":
                    # The spatial size of a video frame comes from the owning
                    # video's grid. Reading image_grid_thw here would fail
                    # for video-only prompts and give wrong sizes for mixed
                    # image+video prompts.
                    t, h, w = (
                        video_frame_num,
                        *video_grid_thw[video_index][1:],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t,
                        h // spatial_merge_size,
                        w // spatial_merge_size,
                    )

                    for t_idx in range(llm_grid_t):
                        t_index = (
                            torch.tensor(t_idx)
                            .view(-1, 1)
                            .expand(-1, llm_grid_h * llm_grid_w)
                            .flatten()
                        )
                        h_index = (
                            torch.arange(llm_grid_h)
                            .view(1, -1, 1)
                            .expand(1, -1, llm_grid_w)
                            .flatten()
                        )
                        w_index = (
                            torch.arange(llm_grid_w)
                            .view(1, 1, -1)
                            .expand(1, llm_grid_h, -1)
                            .flatten()
                        )
                        llm_pos_ids_list.append(
                            torch.stack([t_index, h_index, w_index]) + st_idx
                        )

                    # Move on to the next video once all of this video's
                    # temporal groups have been consumed.
                    video_group_index += 1
                    if video_group_index >= video_grid_thw[video_index][0]:
                        video_index += 1
                        video_group_index = 0
                    video_frame_num += 1

                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    video_frame_num = 1

        else:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return llm_positions, mrope_position_delta

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the forward pass for GLM-4V through the language model.

        Args:
            input_ids: Flattened (concatenated) input ids for a batch.
            positions: Flattened (concatenated) position ids for a batch.
                **NOTE**: If mrope is enabled (the default for the open-source
                GLM-4V models), the shape is `(3, seq_len)`; otherwise it is
                `(seq_len,)`.
            intermediate_tensors: Optional intermediate tensors for pipeline
                parallelism. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional pre-computed input embeddings.
            **kwargs: Additional keyword arguments. They are accepted but
                not forwarded.

        Returns:
            The output of ``self.language_model.model``.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
1801
    ) -> torch.Tensor | None:
1802
        return self.language_model.compute_logits(hidden_states)
1803

1804
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model.

        Names are translated with ``self.hf_to_vllm_mapper`` before loading.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module-name prefixes of this multimodal model's parts.

        Returns:
            A ``MultiModelKeys`` with these prefixes:

            * language model: ``language_model.model``
            * connector: ``visual.merger.``
            * vision tower: ``visual.``
        """
        prefixes = dict(
            language_model="language_model.model",
            connector="visual.merger.",
            tower_model="visual.",
        )
        return MultiModelKeys.from_string_field(**prefixes)


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }