glm4_1v.py 59.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The ZhipuAI Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""

import math
30
from collections.abc import Callable, Iterable, Mapping, Sequence
31
from functools import partial
32
from typing import Annotated, Any, Literal, TypeAlias
33
34
35
36
37
38
39

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
Yuxuan Zhang's avatar
Yuxuan Zhang committed
40
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
41
from transformers.models.glm4v.image_processing_glm4v import (
42
43
44
45
    Glm4vImageProcessor,
    smart_resize,
)
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
46
47
from transformers.video_utils import VideoMetadata

48
from vllm.attention.backends.registry import _Backend
49
50
51
52
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
53
from vllm.config import VllmConfig
54
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
55
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
56
57
58
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
59
60
61
62
63
64
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
65
66
67
68
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
83
84
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
85
from vllm.utils.tensor_schema import TensorSchema, TensorShape
86
87

from ..layers.activation import SiluAndMul
88
89
90
91
92
93
94
95
96
97
98
99
100
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
101
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
102
103
104
105
106
107
108
109
110

logger = init_logger(__name__)

# For profile run: cap on frames per video. NOTE(review): the consumer of
# this constant (presumably dummy-input / profiling code) is outside this chunk.
_MAX_FRAMES_PER_VIDEO = 600

# === Vision Inputs === #


111
class Glm4vImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for one or more images.

    Dimensions:
        - np: Number of patches
        - cpp: Number of channels * patch_size * patch_size
        - ni: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """

    # Discriminator distinguishing this schema from Glm4vImageEmbeddingInputs.
    type: Literal["pixel_values"] = "pixel_values"

    # One row of pixel values per patch, for all images together.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")]
    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
124
125


126
class Glm4vImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings.

    Dimensions:
        - f: Number of image features (varies based on image resolution)
        - h: Hidden size (must match language model backbone)
        - n: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """

    # Discriminator distinguishing this schema from Glm4vImagePixelInputs.
    type: Literal["image_embeds"] = "image_embeds"

    # One embedding row per image feature, for all images together.
    image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")]
    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]
139
140


141
# Image inputs arrive either as raw pixels ("pixel_values") or as
# precomputed embeddings ("image_embeds"); the `type` field tells them apart.
Glm4vImageInputs: TypeAlias = Glm4vImagePixelInputs | Glm4vImageEmbeddingInputs
142
143


144
class Glm4vVideoPixelInputs(TensorSchema):
    """
    Raw pixel inputs for videos.

    Dimensions:
        - np: Number of patches
        - ctpp: Number of channels * temporal_patch_size *
            patch_size * patch_size
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """

    # Discriminator distinguishing this schema from Glm4vVideoEmbeddingInputs.
    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # One row of pixel values per patch, for all frames together.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
    # One (grid_t, grid_h, grid_w) row per frame.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
159
160


161
class Glm4vVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings.

    Dimensions:
        - p: Number of video patches across all frames
        - h: Hidden size (must match language model backbone)
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """

    # Discriminator distinguishing this schema from Glm4vVideoPixelInputs.
    type: Literal["video_embeds"] = "video_embeds"

    # One embedding row per video patch, across all frames.
    video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
    # One (grid_t, grid_h, grid_w) row per frame.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
175
176


177
# Video inputs arrive either as raw pixels ("pixel_values_videos") or as
# precomputed embeddings ("video_embeds"); the `type` field tells them apart.
Glm4vVideoInputs: TypeAlias = Glm4vVideoPixelInputs | Glm4vVideoEmbeddingInputs
178

179
# ==== Vision Encoder ==== #
180
181
182
183
184
185
186
187


class Glm4vVisionMLP(nn.Module):
    """Gated MLP used inside each vision block.

    A single fused ``gate_up_proj`` matmul produces both gate and up halves,
    ``SiluAndMul`` combines them, and ``down_proj`` maps back to
    ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # Gate and up projections fused into one column-parallel layer
        # (two output shards of `hidden_features` each).
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` across the TP group and interleave shards.

    Every rank's tensor is split along the last dim into chunks of
    ``hidden_size // tp_size``. The result concatenates, along the last dim,
    chunk 0 of every rank (in rank order), then chunk 1 of every rank, etc.
    """
    import torch.distributed as dist

    # all_gather overwrites every output buffer, so zero-filling is wasted work.
    gathered_tensors = [torch.empty_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        gathered_tensors,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )

    chunk_size = hidden_size // tp_size
    gathered_tensors_split = [
        torch.split(tensor, chunk_size, -1) for tensor in gathered_tensors
    ]
    # zip(*...) groups the i-th chunk of every rank together.
    ordered_tensors = [
        tensor for chunks in zip(*gathered_tensors_split) for tensor in chunks
    ]
    return torch.cat(ordered_tensors, dim=-1)


class Glm4vVisionAttention(nn.Module):
    """Multi-head self-attention for the GLM-4V vision tower.

    Attention is restricted to the sequences delimited by ``cu_seqlens`` and
    is non-causal. Supported backends: FLASH_ATTN / ROCM_AITER_FA (varlen
    kernel), TORCH_SDPA (one call per sequence) and XFORMERS (block-diagonal
    mask).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values. `use_data_parallel`
        # disables tensor parallelism for this layer (tp_size=1).
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = (
            0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
        )
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=False,
            quant_config=quant_config,
            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            bias=False,
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split fused ``[s, b, 3 * head * head_dim]`` into per-head q, k, v."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment; returns ``[s, b, embed]``."""
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        # 3 * [s, b, head, head_dim] -> 3 * [b, s, head, head_dim]
        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous() for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k in a single call: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v))

            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) h d -> s b (h d)", b=batch_size
            ).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Copy the boundaries to the host once; indexing a device tensor
            # per iteration would force one implicit sync per slice bound.
            bounds = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(bounds[:-1], bounds[1:]):
                q_i, k_i, v_i = (
                    rearrange(t[:, start_idx:end_idx], "b s h d -> b h s d")
                    for t in (q, k, v)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                outputs.append(rearrange(output_i, "b h s d -> b s h d"))
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal mask keeps each sequence attending only to itself.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

        output, _ = self.proj(context_layer)
        return output


class Glm4vVisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then gated MLP.

    ``norm2`` is called as ``norm2(x, residual=...)`` and must return a
    ``(normalized, residual)`` pair, so ``norm_layer`` has to be a
    residual-fusing norm such as vLLM's ``RMSNorm``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # forward() relies on the fused-residual call signature
            # `norm(x, residual=...)`, which nn.LayerNorm does not accept,
            # so the default must be a residual-capable norm.
            norm_layer = partial(RMSNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Glm4vVisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            mlp_hidden_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        x_attn = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        # Fused residual add + norm: `residual` is expected to be x + x_attn
        # and `x_fused_norm` its normalized form (vLLM RMSNorm semantics).
        x_fused_norm, residual = self.norm2(x, residual=x_attn)
        x = residual + self.mlp(x_fused_norm)
        return x


class Glm4vVisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``hidden_size`` with a Conv3d.

    Kernel and stride both equal one full patch
    ``(temporal_patch_size, patch_size, patch_size)``. The convolution
    therefore produces exactly one output per patch, which is equivalent to a
    linear layer applied to each flattened patch.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches.

        Args:
            x: 2-D tensor with one flattened patch per row
                (``in_channels * temporal_patch_size * patch_size**2`` values).

        Returns:
            Tensor of shape ``(num_patches, hidden_size)``.
        """
        num_patches, _ = x.shape
        x = x.view(
            num_patches, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        return self.proj(x).view(num_patches, self.hidden_size)


class Glm4vPatchMerger(nn.Module):
    """Final projection of merged vision patches.

    The pipeline is ``proj`` -> LayerNorm -> GELU -> gated MLP
    (``gate_up_proj`` + ``SiluAndMul`` + ``down_proj``). The output width is
    ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = d_model
        # gather_output=True: the full output is needed by the LayerNorm below.
        self.proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=bias,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[context_dim] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor):
        x, _ = self.proj(x)
        x = self.extra_activation_func(self.post_projection_norm(x))
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Glm4vVisionEmbeddings(nn.Module):
    """Learned 2D position embeddings resampled to arbitrary patch grids.

    The checkpoint stores a square table of ``(image_size // patch_size) ** 2``
    positions. At runtime the table is bicubically sampled (``F.grid_sample``)
    at each patch's (row, col) coordinate. The coordinate is normalized by the
    grid height/width of the sequence the patch belongs to, and the result is
    added to the patch embeddings.
    """

    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    @staticmethod
    def _per_patch_target_hw(
        lengths: torch.Tensor, image_shapes: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Expand the per-sequence grid (h, w) to one value per patch.

        ``lengths[i]`` patches belong to sequence ``i``. ``image_shapes`` has
        one ``(t, h, w)`` row per item. In Glm4vVisionTransformer.forward an
        item with ``t > 1`` yields ``t`` sequences, so rows are repeated ``t``
        times whenever that accounts for every sequence.
        """
        num_seqs = len(lengths)
        num_items = image_shapes.shape[0]
        if num_seqs > num_items:
            frames_per_item = image_shapes[:, 0]
            if int(frames_per_item.sum()) == num_seqs:
                # One sequence per temporal slice: repeat each grid row t
                # times so every sequence uses the grid of its own item.
                seq_shapes = torch.repeat_interleave(
                    image_shapes, frames_per_item, dim=0
                )
            else:
                # Shapes do not cover every sequence (e.g. a partial shard in
                # data parallel mode): cycle through the available shapes.
                idx = torch.arange(num_seqs, device=image_shapes.device) % num_items
                seq_shapes = image_shapes[idx]
        else:
            seq_shapes = image_shapes[:num_seqs]

        repeats = lengths.to(device=seq_shapes.device, dtype=torch.long)
        target_h = torch.repeat_interleave(seq_shapes[:, 1], repeats)
        target_w = torch.repeat_interleave(seq_shapes[:, 2], repeats)
        return target_h, target_w

    def forward(
        self, embeddings, lengths, image_shapes, h_coords, w_coords
    ) -> torch.Tensor:
        """Add interpolated position embeddings to ``embeddings``.

        Args:
            embeddings: Patch embeddings, one row per patch.
            lengths: Patches per sequence (list/sequence or tensor).
            image_shapes: One ``(t, h, w)`` row per item (``grid_thw``).
            h_coords: Row coordinate of each patch within its grid.
            w_coords: Column coordinate of each patch within its grid.

        Returns:
            ``embeddings`` plus the adapted position embedding.
        """
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        if total_seq == 0:
            adapted_pos_embed = torch.empty(
                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
            )
        else:
            # Convert inputs to tensors if needed
            if not isinstance(lengths, torch.Tensor):
                lengths = torch.tensor(list(lengths), device=device, dtype=torch.long)
            if not isinstance(image_shapes, torch.Tensor):
                image_shapes = torch.tensor(
                    image_shapes, device=device, dtype=torch.long
                )

            # The checkpoint table is square: (orig_size * orig_size, hidden).
            orig_size = math.isqrt(pos_embed_weight.shape[0])
            pos_embed_2d = (
                pos_embed_weight.view(orig_size, orig_size, hidden_size)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device=device, dtype=torch.float32)
            )

            target_h, target_w = self._per_patch_target_hw(lengths, image_shapes)
            target_h = target_h.to(device=device, dtype=torch.float32)
            target_w = target_w.to(device=device, dtype=torch.float32)

            # Normalize patch centers to the [-1, 1] range for grid_sample
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # Sampling grid (1, total_seq, 1, 2); grid_sample reads (x, y) = (w, h).
            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)

            # Perform bicubic interpolation
            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d,
                grid,
                mode="bicubic",
                align_corners=False,
                padding_mode="border",
            )

            # (1, hidden, total_seq, 1) -> (total_seq, hidden), original dtype.
            adapted_pos_embed = (
                interpolated_embed_fp32.squeeze(0)
                .squeeze(-1)
                .permute(1, 0)
                .to(pos_embed_weight.dtype)
                .to(embeddings.device)
            )

        # Add adapted position encoding to embeddings
        return embeddings + adapted_pos_embed


class Glm4vVisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles ``outer(positions, inv_freq)``.

    ``forward(n)`` returns a tensor of shape ``(n, dim // 2)``: row ``p``
    holds the rotary angles for position ``p``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table covers at least ``seqlen`` positions.

        When the table must grow it is rebuilt for ``2 * seqlen`` positions,
        so slightly larger later requests do not trigger another rebuild.
        """
        # Callers may pass a 0-d tensor (e.g. `grid_thw[:, 1:].max()`).
        # Convert it to a Python int so the doubling below cannot mutate the
        # caller's tensor in place through `__imul__`.
        seqlen = int(seqlen)
        if seqlen <= self._seq_len_cached:
            return
        seqlen *= 2
        self._seq_len_cached = seqlen
        # inv_freq is recomputed in float32 on the buffer's current device
        # (presumably to undo any module-wide dtype cast of the buffer).
        self.inv_freq = 1.0 / (
            self.theta
            ** (
                torch.arange(
                    0,
                    self.dim,
                    2,
                    dtype=torch.float,
                    device=self.inv_freq.device,
                )
                / self.dim
            )
        )
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return rotary angles for positions ``0 .. seqlen - 1``."""
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Glm4vVisionTransformer(nn.Module):
    """GLM-4V vision encoder.

    ``forward`` runs the following pipeline:

    1. Conv3d patch embedding, then RMSNorm.
    2. Learned 2D position embeddings resampled to each grid.
    3. ``depth`` vision blocks with 2D rotary embeddings.
    4. RMSNorm, then a spatial-merge Conv2d downsample.
    5. The patch-merger MLP.
    """

    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(RMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        # head_dim // 2 here: rot_pos_emb concatenates the row and column
        # angle vectors of each patch, giving head_dim // 2 angles in total.
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.out_hidden_size,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=self.use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.merger",
            use_data_parallel=self.use_data_parallel,
        )
        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = RMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )
        self.downsample = nn.Conv2d(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = RMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )

        # Backend used only to decide whether max_seqlen must be computed.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build per-patch rotary embeddings and (row, col) position ids.

        Within each item, patches are ordered window by window
        (``spatial_merge_size x spatial_merge_size``), and every temporal
        slice reuses the same ids.

        Returns:
            ``(rotary_pos_emb, pos_ids)``. ``pos_ids`` has shape
            ``(num_patches, 2)`` and holds the (row, col) of each patch.
        """
        merge = self.spatial_merge_size
        # One host copy instead of a device sync per element when grid_thw
        # lives on an accelerator.
        grid_rows = (
            grid_thw.tolist() if isinstance(grid_thw, torch.Tensor) else list(grid_thw)
        )
        pos_ids = []
        for t, h, w in grid_rows:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # Plain int: the rotary table only needs the largest grid side.
        max_grid_size = max(max(h, w) for _, h, w in grid_rows)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> tuple[int | None, list[int] | None]:
        """Return ``(max_seqlen, seqlens)`` derived from ``cu_seqlens``.

        ``seqlens`` is always returned; ``forward`` also passes it to the
        position embedding. ``max_seqlen`` is only computed for
        flash-attention backends. Both values come from a single host copy.
        """
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        max_seqlen = None
        if self.attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
            max_seqlen = max(seqlens)
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision embeddings.

        Args:
            x: One flattened pixel patch per row.
            grid_thw: One ``(t, h, w)`` patch-grid row per item.

        Returns:
            Tensor with ``out_hidden_size`` columns, one row per merged patch.
        """
        # Convert grid_thw to tensor (always expecting list format now)
        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)

        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        # compute position embedding
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        # compute cu_seqlens: one attention sequence of h * w patches per
        # temporal slice of every item.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        x = self.embeddings(
            x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
        )

        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.post_layernorm(x)

        # Group each merge window, then downsample it with the strided conv.
        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v and gate/up shards.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Glm4vProcessingInfo(BaseProcessingInfo):
    """Processing information for GLM-4.1V.

    Exposes the HF config, tokenizer and processors. Also computes:
    resized frame sizes, the number of vision tokens per image or video,
    and the token ids of a video placeholder (per-frame vision tokens
    interleaved with timestamp text).
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_tokenizer(self):
        return self.ctx.tokenizer

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # `None` leaves the image count uncapped; at most one video per prompt.
        return {"image": None, "video": 1}

    def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 16,
        do_resize: bool = True,
        max_image_pixels: int = 28 * 28 * 2 * 30000,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed frame size and the vision-token count.

        Args:
            image_width: Original frame width in pixels.
            image_height: Original frame height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` first.
            max_image_pixels: Pixel budget passed to ``smart_resize``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``. The token count is
            ``grid_t * grid_h * grid_w // spatial_merge_size**2``.
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size
        if do_resize:
            resized_height, resized_width = smart_resize(
                num_frames=num_frames
                if num_frames > temporal_patch_size
                else temporal_patch_size,
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                max_pixels=max_image_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded up to the next multiple of
        # `temporal_patch_size` (ceil division); this is correct for any
        # temporal patch size, not only 2.
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = (
            (num_frames + temporal_patch_size - 1)
            // temporal_patch_size
            * temporal_patch_size
        )

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        # An oversized input is shrunk by smart_resize to the largest
        # size that fits the default pixel budget.
        max_image_size, _ = self._get_vision_info(
            image_width=9999999, image_height=9999999
        )
        return max_image_size

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            max_image_pixels=28 * 28 * 2 * 6144,
        )
        return num_image_tokens

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            max_image_pixels=28 * 28 * 2 * 30000,
        )
        return num_video_tokens

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose video-token count fits.

        The frame size used is the maximum from
        ``get_image_size_with_most_features``. Returns 0 if even one frame
        exceeds ``max_tokens``.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )
            if next_max_tokens > max_tokens or next_max_tokens == 0:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the frames per video that fit in ``seq_len``.

        Tokens for the maximum-size images are reserved first. The result
        is capped at ``_MAX_FRAMES_PER_VIDEO`` and is never below 1.
        """
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
        )

        return max(max_frames_per_video, 1)

    def _get_video_second_idx(
        self, metadata: dict[str, Any], total_frames: int
    ) -> list[int]:
        """Return the whole-second timestamp of every second selected frame.

        Frame indices come from ``metadata["frames_indices"]`` when frame
        sampling is disabled. Otherwise they are re-derived from fps and
        duration the same way the video processor samples them.
        ``do_sample_frames`` defaults to ``True`` when absent, matching
        ``Glm4vMultiModalProcessor._call_hf_processor``.
        """
        video_processor = self.get_video_processor()

        video_fps = metadata.get("fps", video_processor.fps)
        meta_frames = metadata.get("total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)

        if not metadata.get("do_sample_frames", True):
            frame_indices = metadata["frames_indices"]
        elif duration <= video_processor.max_duration:
            n = int(math.floor(duration * video_processor.fps))
            frame_indices = [
                min(
                    max_frame_idx,
                    int(math.ceil(i * video_fps / video_processor.fps)),
                )
                for i in range(n)
            ]
        else:
            num_samples = int(video_processor.max_duration * video_processor.fps)
            if num_samples >= meta_frames:
                frame_indices = list(range(meta_frames))
            else:
                target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
                frame_indices = [
                    min(max_frame_idx, int(math.ceil(t * video_fps)))
                    for t in target_seconds
                ]

        # Drop duplicate indices (keeping first-seen order), then pad to an
        # even count by repeating the last one; timestamps below take every
        # second frame.
        frame_indices = list(dict.fromkeys(frame_indices))
        if len(frame_indices) % 2:
            frame_indices.append(frame_indices[-1])

        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
        return full_second_idxs[::2]

    def _construct_video_placeholder(
        self,
        video_array: np.ndarray,
        metadata: dict[str, Any],
        grid_thw: torch.Tensor,
    ) -> list[int]:
        """Build the token ids that replace one video in the prompt.

        Layout: ``<bov>`` then, per temporal grid step, ``<boi>``, the video
        tokens for that frame, ``<eoi>`` and the tokenized timestamp
        (seconds); finally ``<eov>``.

        Args:
            video_array: Decoded frames; only ``len(video_array)`` is used,
                as a fallback frame count.
            metadata: Video metadata dict (fps, duration, frame indices, ...).
            grid_thw: 1-D ``(T, H, W)`` patch grid for this video.

        Returns:
            The placeholder as a list of token ids.
        """
        hf_processor = self.get_hf_processor()
        tokenizer = self.get_tokenizer()
        image_processor = hf_processor.image_processor

        hf_config = self.get_hf_config()
        boi_token_id = hf_config.image_start_token_id
        eoi_token_id = hf_config.image_end_token_id
        bov_token_id = hf_config.video_start_token_id
        eov_token_id = hf_config.video_end_token_id
        merge_length = image_processor.merge_size**2

        assert isinstance(grid_thw, torch.Tensor)
        num_frames, grid_h, grid_w = (int(v) for v in grid_thw)
        num_tokens_per_frame = grid_h * grid_w // merge_length

        timestamps = self._get_video_second_idx(metadata, len(video_array))
        # Emit exactly one frame per temporal grid step. The video-token count
        # then equals the T * H * W // merge_length embeddings produced for
        # this item. Surplus timestamps are dropped; missing ones repeat the
        # last value (0 if none).
        timestamps = timestamps[:num_frames]
        if len(timestamps) < num_frames:
            pad_value = timestamps[-1] if timestamps else 0
            timestamps = timestamps + [pad_value] * (num_frames - len(timestamps))

        placeholder = [bov_token_id]
        for timestamp in timestamps:
            placeholder.append(boi_token_id)
            placeholder.extend([hf_processor.video_token_id] * num_tokens_per_frame)
            placeholder.append(eoi_token_id)
            placeholder.extend(
                tokenizer.encode(str(timestamp), add_special_tokens=False)
            )
        placeholder.append(eov_token_id)

        return placeholder

1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150

class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
    """Builds worst-case dummy prompts and image/video data for GLM-4.1V.

    Sizes come from ``Glm4vProcessingInfo``'s ``*_with_most_features``
    helpers.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_config = self.info.get_hf_config()
        hf_processor = self.info.get_hf_processor()
        tokenizer = self.info.get_tokenizer()

        image_token: str = hf_processor.image_token
        # A video placeholder is <begin_of_video><video><end_of_video>.
        video_token_ids = [
            hf_config.video_start_token_id,
            hf_processor.video_token_id,
            hf_config.video_end_token_id,
        ]
        video_token = tokenizer.decode(video_token_ids)

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[VideoItem]:
        """Create ``num_videos`` white dummy videos, each with metadata.

        Overrides may only shrink ``num_frames``/``width``/``height``.
        Larger values are logged and ignored.

        Returns:
            ``(frames, metadata)`` tuples. ``frames`` has shape
            ``(num_frames, height, width, 3)``, dtype uint8.
        """
        if overrides:
            if overrides.num_frames:
                if overrides.num_frames > num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        overrides.num_frames,
                        num_frames,
                    )
                num_frames = min(num_frames, overrides.num_frames)
            if overrides.width:
                if overrides.width > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        overrides.width,
                        width,
                    )
                width = min(width, overrides.width)
            if overrides.height:
                if overrides.height > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        overrides.height,
                        height,
                    )
                height = min(height, overrides.height)

        # Channels-last video frames are laid out (T, H, W, C). Height must
        # come before width, or non-square overrides produce transposed frames.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            # Frames are used as-is (no resampling), so every index is listed.
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))

        return video_items


class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
    """Multimodal processor for GLM-4.1V.

    Videos are sent through the HF processor one at a time, separately from
    the text and images. This class also defines how image and video
    placeholders in the prompt are expanded.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Video items are (frames, metadata) pairs; the metadata is needed to
        # compute the per-frame timestamps in the placeholder.
        return MultiModalDataParser(video_needs_metadata=True)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling each video in its own call.

        GLM-4.1V uses ``image_token_id`` as the per-frame placeholder inside
        videos. Each video is therefore processed alone. In its expanded
        prompt the image token ids are swapped for ``video_token_id``, and
        the result is decoded and substituted for the first remaining video
        placeholder in ``prompt``. Text and images are then processed
        together. The video tensors are concatenated and merged into the
        returned features.
        """
        video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        videos = mm_data.get("videos")
        if isinstance(videos, list) and videos:
            grid_list = []
            pixel_list = []
            for video_array, metadata in mm_data.pop("videos"):
                # Fresh kwargs per video so the caller's mapping is untouched.
                per_video_kwargs = {
                    **mm_kwargs,
                    "do_sample_frames": metadata.get("do_sample_frames", True),
                }
                kept_metadata = {
                    key: value
                    for key, value in metadata.items()
                    if key != "do_sample_frames"
                }
                per_video_data = {
                    "videos": [[video_array]],
                    "video_metadata": [[VideoMetadata(**kept_metadata)]],
                }

                per_video_out = super()._call_hf_processor(
                    prompt=video_prompt,
                    mm_data=per_video_data,
                    mm_kwargs=per_video_kwargs,
                    tok_kwargs=tok_kwargs,
                )

                token_ids = per_video_out.pop("input_ids")
                is_image_token = token_ids == processor.image_token_id
                token_ids[is_image_token] = processor.video_token_id
                expanded = processor.tokenizer.batch_decode(token_ids)[0]
                prompt = prompt.replace(video_prompt, expanded, 1)

                grid_list.append(per_video_out["video_grid_thw"])
                pixel_list.append(per_video_out["pixel_values_videos"])

            video_outputs = {
                "pixel_values_videos": torch.cat(pixel_list),
                "video_grid_thw": torch.cat(grid_list),
            }
        else:
            video_outputs = {}

        text_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        return BatchFeature({**text_outputs, **video_outputs})

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        spatial_merge_size = (
            self.info.get_hf_config().vision_config.spatial_merge_size
        )
        make_fields = _create_qwen2vl_field_factory(spatial_merge_size)
        return make_fields(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe how each image and video placeholder is expanded.

        An image token becomes ``prod(image_grid_thw) // merge_size**2``
        image tokens. A video placeholder becomes the sequence built by
        ``Glm4vProcessingInfo._construct_video_placeholder``; only its
        ``video_token_id`` positions receive embeddings.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokens_per_merge = image_processor.merge_size**2

        def get_image_replacement_glm4v(item_idx: int):
            image_grid = out_mm_kwargs["image"][item_idx]["image_grid_thw"].data
            assert isinstance(image_grid, torch.Tensor)

            token_count = int(image_grid.prod()) // tokens_per_merge
            return [hf_processor.image_token_id] * token_count

        def get_video_replacement_glm4v(item_idx: int):
            video_grid = out_mm_kwargs["video"][item_idx]["video_grid_thw"].data
            assert isinstance(video_grid, torch.Tensor)

            frames, video_metadata = mm_items["video"][item_idx]
            token_ids = self.info._construct_video_placeholder(
                frames, video_metadata, video_grid
            )
            return PromptUpdateDetails.select_token_id(
                token_ids,
                embed_token_id=hf_processor.video_token_id,
            )

        image_update = PromptReplacement(
            modality="image",
            target=hf_processor.image_token,
            replacement=get_image_replacement_glm4v,
        )
        video_update = PromptReplacement(
            modality="video",
            target="<|begin_of_video|><|video|><|end_of_video|>",
            replacement=get_video_replacement_glm4v,
        )
        return [image_update, video_update]


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
    """GLM-4.1V: a vision transformer (``visual``) plus a GLM-4 language model.

    The language model (``language_model``) is chosen from
    ``config.model_type``.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_up_proj"],
    }

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
        }
    )

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|begin_of_image|><|image|><|end_of_image|>"
        if modality.startswith("video"):
            return "<|begin_of_video|><|video|><|end_of_video|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # Read both encoder settings under the same None-guard. Without a
        # multimodal config, fall back to the tensor-parallel encoder and the
        # default attention backend.
        if multimodal_config is not None:
            self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        else:
            self.use_data_parallel = False
            attn_backend_override = None

        self.visual = Glm4vVisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
            attn_backend_override=attn_backend_override,
        )

        if config.model_type == "glm4v":
            architectures = ["Glm4ForCausalLM"]
        elif config.model_type == "glm4v_moe":
            architectures = ["Glm4MoeForCausalLM"]
        else:
            architectures = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Glm4vImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        # Raw pixels take precedence over precomputed embeddings.
        if pixel_values is not None:
            return Glm4vImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        return Glm4vImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Glm4vVideoInputs | None:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        # Raw pixels take precedence over precomputed embeddings.
        if pixel_values_videos is not None:
            return Glm4vVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        return Glm4vVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=video_embeds,
            video_grid_thw=video_grid_thw,
        )

    def _split_embeds_by_grid(
        self, embeds: torch.Tensor, grid_thw_list: list[list[int]]
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated encoder output into one tensor per item.

        Item ``i`` contributes ``t * h * w // spatial_merge_size**2`` rows,
        where ``(t, h, w) = grid_thw_list[i]``.
        """
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Glm4vImageInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                # Returned as-is, without the per-item split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

        return self._split_embeds_by_grid(image_embeds, grid_thw_list)

    def _process_video_input(
        self, video_input: Glm4vVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            if self.use_data_parallel:
                # Returned as-is, without the per-item split below.
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw_list,
                    rope_type="rope_3d",
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        return self._split_embeds_by_grid(video_embeds, grid_thw_list)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> MultiModalEmbeddings | None:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return None

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            # A key whose tensors were all None parses to None; skip it.
            if multimodal_input is None:
                continue
            if modality == "image":
                multimodal_embeddings += tuple(
                    self._process_image_input(multimodal_input)
                )
            elif modality == "video":
                multimodal_embeddings += tuple(
                    self._process_video_input(multimodal_input)
                )
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for GLM-4V.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for GLM-4V
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Optional intermediate tensors for pipeline
                parallelism. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional pre-computed input embeddings.
            **kwargs: Additional keyword arguments (unused here).
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # HF checkpoint prefixes are remapped via hf_to_vllm_mapper.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.model",
            connector="visual.merger.",
            tower_model="visual.",
        )
Jee Jee Li's avatar
Jee Jee Li committed
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
    """GLM-4.1V variant registered with the same processor, info and
    dummy-input builder as ``Glm4vForConditionalGeneration``.

    Presumably paired with the ``glm4v_moe`` config, for which the parent
    ``__init__`` selects ``Glm4MoeForCausalLM`` as the language model.
    Overrides ``packed_modules_mapping``.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        # Unlike the dense parent (where gate_up_proj maps to itself), the
        # fused gate_up_proj here is packed from separate gate/up modules.
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }