glm4_1v.py 63.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The ZhipuAI Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""

zhuwenwen's avatar
zhuwenwen committed
29
import os
30
31
32
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
33
from typing import Annotated, Any, Callable, Literal, Optional, Union
34
35
36
37
38
39

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
40
from packaging.version import Version
41
from transformers import BatchFeature
42
from transformers import __version__ as TRANSFORMERS_VERSION
Yuxuan Zhang's avatar
Yuxuan Zhang committed
43
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
44
45
46
47
48
49
from transformers.models.glm4v.image_processing_glm4v import (
    Glm4vImageProcessor, smart_resize)
from transformers.models.glm4v.video_processing_glm4v import (
    Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata

50
from vllm.attention.layer import check_upstream_fa_availability
51
from vllm.config import VllmConfig
52
53
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              parallel_state)
54
55
56
57
58
59
60
61
62
63
64
65
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
66
                                    MultiModalKwargsItems, VideoItem)
67
68
69
70
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
71
                                        PromptUpdate, PromptUpdateDetails)
72
73
74
75
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
76
from vllm.utils.tensor_schema import TensorSchema, TensorShape
77
78
79
80

from ..layers.activation import SiluAndMul
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
81
82
from .qwen2_vl import (_create_qwen2vl_field_factory,
                       apply_rotary_pos_emb_vision)
83
84
85
from .utils import (AutoWeightsLoader, WeightsMapper,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
86
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
87
88
89
90
91
92
93
94
95

# Module-level logger for this file.
logger = init_logger(__name__)

# For profile run. Presumably an upper bound on frames per video assumed when
# building dummy profiling inputs (its use is outside this view).
_MAX_FRAMES_PER_VIDEO = 600

# === Vision Inputs === #


96
class Glm4vImagePixelInputs(TensorSchema):
    """
    Pixel-value input for images (patches not yet run through the encoder).

    Dimensions:
        - np: Number of patches
        - cpp: Number of channels * patch_size * patch_size
        - ni: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """
    type: Literal["pixel_values"] = "pixel_values"

    # Flattened image patches, shape (np, cpp).
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")]
    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
108
109


110
class Glm4vImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision encoder.

    Dimensions:
        - f: Number of image features (varies based on image resolution)
        - h: Hidden size (must match language model backbone)
        - n: Number of images
        - g: Grid dimensions (3 for grid_t, grid_h, grid_w)
    """
    type: Literal["image_embeds"] = "image_embeds"

    # Image feature vectors, shape (f, h).
    image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")]
    # One (grid_t, grid_h, grid_w) row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]
122
123
124
125
126


# Image input: either raw pixel patches or precomputed embeddings.
Glm4vImageInputs = Union[Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs]


127
class Glm4vVideoPixelInputs(TensorSchema):
    """
    Pixel-value input for videos (patches not yet run through the encoder).

    Dimensions:
        - np: Number of patches
        - ctpp: Number of channels * temporal_patch_size *
            patch_size * patch_size
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """
    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # Flattened video patches, shape (np, ctpp).
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
    # One (grid_t, grid_h, grid_w) row per frame entry.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
141
142


143
class Glm4vVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings that bypass the vision encoder.

    Dimensions:
        - p: Number of video patches across all frames
        - h: Hidden size (must match language model backbone)
        - f: Number of frames
        - g: Grid dimensions (3 for grid_t which is usually 1 for processed
          video, grid_h, grid_w)
    """
    type: Literal["video_embeds"] = "video_embeds"

    # Video feature vectors, shape (p, h).
    video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
    # One (grid_t, grid_h, grid_w) row per frame entry.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
156
157
158
159


# Video input: either raw pixel patches or precomputed embeddings.
Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]

160
# ==== Vision Encoder ==== #
161
162
163
164
165
166
167
168
169
170


class Glm4vVisionMLP(nn.Module):
    """Gated SiLU MLP used inside each vision transformer block.

    ``gate_up_proj`` computes the gate and up projections in one merged
    linear layer, ``SiluAndMul`` combines them, and ``down_proj`` maps the
    result back to ``in_features``.

    Args:
        in_features: Input (and output) feature size.
        hidden_features: Size of each of the gate / up projections.
        bias: Whether the linear layers carry a bias.
        quant_config: Optional quantization config for the linear layers.
        prefix: Weight-name prefix for this module.
        use_data_parallel: If True, tensor parallelism is disabled for the
            linear layers (``disable_tp``).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, gated activation, then down projection."""
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """Gather ``local_tensor`` from every TP rank and interleave the pieces.

    Each gathered tensor is cut along its last dimension into pieces of width
    ``hidden_size // tp_size``. The result is piece 0 from every rank (in rank
    order), then piece 1 from every rank, and so on, all concatenated along
    the last dimension.
    """
    import torch.distributed as dist

    buffers = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    tp_group = parallel_state.get_tp_group().device_group
    dist.all_gather(buffers, local_tensor, group=tp_group)

    piece_width = hidden_size // tp_size
    pieces_per_rank = [buf.split(piece_width, dim=-1) for buf in buffers]

    interleaved = []
    for same_position_pieces in zip(*pieces_per_rank):
        interleaved.extend(same_position_pieces)
    return torch.cat(interleaved, dim=-1)


class Glm4vVisionAttention(nn.Module):
    """Multi-head self-attention for the GLM-4V vision encoder.

    Inputs hold several packed sequences whose boundaries are given by
    ``cu_seqlens``; each segment attends only to itself. The FLASH_ATTN,
    TORCH_SDPA and XFORMERS backends are supported; any other backend raises
    ``RuntimeError`` at construction time.

    Args:
        embed_dim: Input / output hidden size.
        num_heads: Total number of attention heads (across all TP ranks).
        projection_size: Total q/k/v projection size; divided by
            ``num_heads`` to get the per-head size.
        quant_config: Optional quantization config for ``qkv`` / ``proj``.
        prefix: Weight-name prefix for this module.
        use_data_parallel: If True, run without tensor parallelism
            (tp_size=1, tp_rank=0, ``disable_tp`` on the linear layers).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_rank())
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=False,
            quant_config=quant_config,
            # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            bias=False,
            disable_tp=use_data_parallel,
        )

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(torch.get_default_dtype()):
            # Upgrade to flash attention when the upstream package is usable.
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN,
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
        }:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now.")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused ``[s, b, 3 * head * head_dim]`` tensor into q, k, v.

        Each output has shape ``[s, b, heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (t.view(*new_shape) for t in (q, k, v))
        return q, k, v

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: Optional[torch.Tensor],
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each packed segment and project the result.

        Args:
            x: ``[s, b, c]`` sequence-first input.
            cu_seqlens: Cumulative segment boundaries along ``s``.
            rotary_pos_emb: Rotary table applied to q and k, or None to skip.
            max_seqlen: Longest segment length (FLASH_ATTN only).
            seqlens: Per-segment lengths (XFORMERS only).

        Returns:
            ``[s, b, embed_dim]`` output of ``proj``.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous()
                   for t in (q, k, v))
        if rotary_pos_emb is not None:
            # Rotate q and k together: [2 * b, s, heads, head_dim]
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.attn_backend == _Backend.FLASH_ATTN:
            # NOTE(review): this always imports the upstream `flash_attn`
            # package, independent of `self.use_upstream_fa`; the flag-based
            # choice of `vllm.vllm_flash_attn` is disabled in this version.
            # `flash_attn` must be installed whenever this backend is chosen.
            from flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0,
                causal=False,
            )

            context_layer = rearrange(output,
                                      "(b s) h d -> s b (h d)",
                                      b=batch_size).contiguous()
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Read the boundaries as Python ints once, instead of indexing
            # the (possibly device-resident) tensor on every iteration.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(t, "b s h d -> b h s d")
                                 for t in [q_i, k_i, v_i])
                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                output_i = rearrange(output_i, "b h s d -> b s h d")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal mask keeps every segment isolated.
            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)
            context_layer = rearrange(context_layer,
                                      "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Glm4vVisionBlock(nn.Module):
    """Pre-norm vision transformer block: self-attention then gated MLP.

    ``norm2`` is called as ``norm2(x, residual=x_attn)`` and its result is
    unpacked into ``(normed, residual)``. The residual add is therefore fused
    into the norm, and ``norm_layer`` must produce modules that support that
    calling convention; the vision transformer in this file passes
    ``RMSNorm``.

    Args:
        dim: Hidden size of the block.
        num_heads: Number of attention heads.
        mlp_hidden_dim: Hidden size of the gated MLP.
        norm_layer: Factory ``norm_layer(dim) -> nn.Module``. Defaults to
            ``RMSNorm`` with ``eps=1e-6``.
        quant_config: Optional quantization config.
        prefix: Weight-name prefix for this module.
        use_data_parallel: If True, disable tensor parallelism in sublayers.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            # `forward` needs a norm that accepts `residual=` and returns a
            # tuple. `nn.LayerNorm` does neither, so the old LayerNorm default
            # failed on the first forward call.
            norm_layer = partial(RMSNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Glm4vVisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            mlp_hidden_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
            self,
            x: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
            max_seqlen: Optional[int] = None,  # Only used for Flash Attention
            seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run attention and MLP sublayers with residual connections."""
        x_attn = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        # Fused residual add + norm; `residual` is presumably x + x_attn.
        x_fused_norm, residual = self.norm2(x, residual=x_attn)
        x = residual + self.mlp(x_fused_norm)
        return x


class Glm4vVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a non-overlapping 3D convolution.

    The kernel and stride both equal
    ``(temporal_patch_size, patch_size, patch_size)``, so each flattened patch
    maps to exactly one ``hidden_size`` vector.

    Args:
        patch_size: Spatial patch edge length.
        temporal_patch_size: Number of frames per temporal patch.
        in_channels: Number of input channels.
        hidden_size: Output embedding size.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project flattened patches to embeddings.

        Args:
            x: 2-D ``(num_patches, in_channels * temporal_patch_size *
                patch_size * patch_size)`` tensor of flattened patches.

        Returns:
            ``(num_patches, hidden_size)`` patch embeddings.
        """
        # Unpacking also enforces a 2-D input.
        num_patches, _ = x.shape
        x = x.view(num_patches, -1, self.temporal_patch_size, self.patch_size,
                   self.patch_size)
        # Opt into the channels-last-3D layout when the environment asks
        # MIOpen to prefer NDHWC.
        if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
            x = x.to(memory_format=torch.channels_last_3d)
        return self.proj(x).view(num_patches, self.hidden_size)


class Glm4vPatchMerger(nn.Module):
    """Post-processing head applied to merged vision features.

    Pipeline: ``proj`` -> LayerNorm -> GELU -> gated SiLU MLP
    (``gate_up_proj`` / ``SiluAndMul`` / ``down_proj``). Input and output
    feature size are both ``d_model``.

    Args:
        d_model: Input / output feature size.
        context_dim: Hidden size of each gate / up projection.
        quant_config: Optional quantization config for the linear layers.
        bias: Whether the linear layers carry a bias.
        prefix: Weight-name prefix for this module.
        use_data_parallel: If True, tensor parallelism is disabled for the
            linear layers (``disable_tp``).
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = d_model
        self.proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=bias,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.post_projection_norm = nn.LayerNorm(self.hidden_size)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[context_dim] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, normalize + GELU, then apply the gated MLP."""
        x, _ = self.proj(x)
        x = self.extra_activation_func(self.post_projection_norm(x))
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Glm4vVisionEmbeddings(nn.Module):
    """Learned 2D absolute position embeddings, resampled per image.

    The table has ``(image_size // patch_size) ** 2`` entries laid out as a
    square grid. At runtime it is bicubically resampled (``grid_sample``) at
    each patch's normalized (h, w) location inside its own image grid, so
    images of any resolution receive position information.

    Args:
        config: Vision config; ``hidden_size``, ``image_size`` and
            ``patch_size`` are read.
    """

    def __init__(self, config: "Glm4vVisionConfig"):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    @staticmethod
    def _per_patch_grid_sizes(
        lengths, image_shapes: torch.Tensor, device: torch.device
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Expand each image's (h, w) grid size to one float32 value per patch.

        Row ``i % len(image_shapes)`` of ``image_shapes`` is repeated
        ``lengths[i]`` times. When there are at least as many rows as
        ``lengths`` entries this is simply row ``i``. Otherwise rows are
        reused cyclically, for data-parallel ranks that may not receive every
        image shape.
        """
        num_shapes = image_shapes.shape[0]
        if num_shapes == 0:
            raise ValueError(
                "image_shapes must contain at least one (t, h, w) row")
        src_device = image_shapes.device
        repeats = torch.as_tensor(lengths, dtype=torch.long, device=src_device)
        shape_idx = torch.arange(repeats.shape[0],
                                 device=src_device) % num_shapes
        selected = image_shapes[shape_idx]
        target_h = torch.repeat_interleave(selected[:, 1], repeats)
        target_w = torch.repeat_interleave(selected[:, 2], repeats)
        return (target_h.to(device=device, dtype=torch.float32),
                target_w.to(device=device, dtype=torch.float32))

    def forward(self, embeddings, lengths, image_shapes, h_coords,
                w_coords) -> torch.Tensor:
        """Add resampled position embeddings to ``embeddings``.

        Args:
            embeddings: ``(total_seq, hidden_size)`` patch embeddings.
            lengths: Number of patches per image (list or 1-D tensor).
            image_shapes: ``(num_images, 3)`` rows of ``(t, h, w)`` grid
                sizes (tensor or nested list).
            h_coords: ``(total_seq,)`` row index of each patch.
            w_coords: ``(total_seq,)`` column index of each patch.

        Returns:
            ``embeddings`` plus the interpolated position embeddings.
        """
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        if total_seq == 0:
            adapted_pos_embed = torch.empty(0,
                                            hidden_size,
                                            device=device,
                                            dtype=pos_embed_weight.dtype)
        else:
            if not isinstance(image_shapes, torch.Tensor):
                image_shapes = torch.tensor(image_shapes,
                                            device=device,
                                            dtype=torch.long)
            target_h, target_w = self._per_patch_grid_sizes(
                lengths, image_shapes, device)

            # Learned table viewed as a (1, hidden, side, side) fp32 image.
            orig_size = math.isqrt(pos_embed_weight.shape[0])
            pos_embed_2d = (pos_embed_weight.view(
                orig_size, orig_size,
                hidden_size).permute(2, 0, 1).unsqueeze(0).to(
                    device=device, dtype=torch.float32))

            # Patch centres normalized to [-1, 1] for grid_sample
            # (align_corners=False convention).
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # Sampling grid (1, total_seq, 1, 2); last dim is (x=w, y=h).
            grid = torch.stack((norm_w, norm_h),
                               dim=-1).unsqueeze(0).unsqueeze(2)

            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d,
                grid,
                mode="bicubic",
                align_corners=False,
                padding_mode="border",
            )

            # (1, hidden, total_seq, 1) -> (total_seq, hidden)
            adapted_pos_embed = interpolated_embed_fp32.squeeze(0).squeeze(
                -1).permute(1, 0).to(pos_embed_weight.dtype)

        # Move to the embeddings' device in both branches before adding.
        return embeddings + adapted_pos_embed.to(embeddings.device)


class Glm4vVisionRotaryEmbedding(nn.Module):
    """Cached rotary-angle table for the vision encoder.

    ``forward(seqlen)`` returns a ``(seqlen, len(inv_freq))`` tensor whose row
    ``p`` is ``p * inv_freq``. The table is cached. Whenever a longer sequence
    is requested, it is rebuilt at twice the requested length.

    Args:
        dim: Rotary dimension; ``inv_freq`` has one entry per index in
            ``range(0, dim, 2)``.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure the cached table has at least ``seqlen`` rows.

        ``seqlen`` may be an int or a 0-d tensor (e.g. a ``.max()`` result).
        It is converted to a Python int first. An in-place ``*=`` on a tensor
        argument would otherwise mutate the caller's tensor and make
        ``forward`` return twice as many rows as requested.
        """
        seqlen = int(seqlen)
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
            # Recompute frequencies in float32 on the buffer's current device
            # instead of reusing the buffer (presumably because the buffer
            # may have been cast together with the module's dtype).
            self.inv_freq = 1.0 / (self.theta**(torch.arange(
                0,
                self.dim,
                2,
                dtype=torch.float,
                device=self.inv_freq.device,
            ) / self.dim))
            seq = torch.arange(seqlen,
                               device=self.inv_freq.device,
                               dtype=self.inv_freq.dtype)
            self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the first ``seqlen`` rows of the cached angle table."""
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Glm4vVisionTransformer(nn.Module):

    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the GLM-4V vision tower.

        Args:
            vision_config: Vision config providing patch/merge sizes, depth,
                head count, hidden sizes and ``rms_norm_eps``.
            norm_eps: ``eps`` of the ``RMSNorm`` layers inside each block.
            quant_config: Optional quantization config for linear layers.
            prefix: Weight-name prefix for this module.
            use_data_parallel: If True, sublayers run without tensor
                parallelism.
        """
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(RMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([
            Glm4vVisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.out_hidden_size,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=self.use_data_parallel,
            ) for layer_idx in range(depth)
        ])
        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.merger",
            use_data_parallel=self.use_data_parallel,
        )
        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = RMSNorm(vision_config.hidden_size,
                                           eps=vision_config.rms_norm_eps)
        self.downsample = nn.Conv2d(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                      eps=vision_config.rms_norm_eps)

        # Same backend selection as Glm4vVisionAttention, so that
        # compute_attn_mask_seqlen knows whether max_seqlen is needed.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        """Parameter dtype of the vision tower, read from the patch conv."""
        conv_weight = self.patch_embed.proj.weight
        return conv_weight.dtype

    @property
    def device(self) -> torch.device:
        """Device holding the vision tower, read from the patch conv."""
        conv_weight = self.patch_embed.proj.weight
        return conv_weight.device

    def rot_pos_emb(
            self,
            grid_thw: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Build rotary embeddings and (h, w) position ids for all patches.

        Patches are enumerated so that the ``spatial_merge_size`` x
        ``spatial_merge_size`` window of each merged token is contiguous, and
        each item's positions are repeated ``t`` times.

        Args:
            grid_thw: ``(num_items, 3)`` rows of ``(t, h, w)`` grid sizes.

        Returns:
            ``(rotary_pos_emb, pos_ids)``. ``pos_ids`` is ``(num_patches, 2)``
            holding the (h, w) index of each patch. ``rotary_pos_emb``
            concatenates, per patch, the rotary rows for its h and w index.
        """
        merge = self.spatial_merge_size

        def _window_major(ids: torch.Tensor, h, w) -> torch.Tensor:
            # Reorder an (h, w) grid so every merge window is contiguous.
            return ids.reshape(h // merge, merge, w // merge,
                               merge).permute(0, 2, 1, 3).flatten()

        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            pos_ids.append(
                torch.stack([
                    _window_major(hpos_ids, h, w),
                    _window_major(wpos_ids, h, w),
                ],
                            dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # Pass a plain int so the rotary cache never receives (and could
        # never mutate) a tensor owned by this method.
        max_grid_size = int(grid_thw[:, 1:].max())
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Derive per-sequence lengths from cumulative offsets.

        Args:
            cu_seqlens: 1-D tensor of cumulative sequence offsets (``forward``
                builds it with a leading 0).

        Returns:
            ``(max_seqlen, seqlens)``. ``seqlens`` is always the Python list
            of individual lengths. ``max_seqlen`` is only computed for the
            FlashAttention backend and is ``None`` otherwise.
        """
        # Compute the differences once and reuse them for both results.
        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        seqlens = lengths.tolist()
        max_seqlen = None
        if self.attn_backend == _Backend.FLASH_ATTN:
            max_seqlen = lengths.max().item()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode pixel patches of all items into merged vision embeddings.

        Args:
            x: Pixel input for all items, concatenated along dim 0. Its exact
                layout is whatever ``self.patch_embed`` expects (not checked
                here).
            grid_thw: One ``[t, h, w]`` patch-grid entry per item.

        Returns:
            The output of ``self.merger`` applied to the downsampled
            features.
        """
        # Convert grid_thw to tensor (always expecting list format now)
        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)

        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        # compute position embedding
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        # compute cu_seqlens: each of the t frames of an item is its own
        # sequence of h * w patches.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cumsum(
                                                 dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # pre-compute seqlens for attn mask to reduce cuMemcpy operations
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0],
                            image_type_ids[:, 1])

        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter
        x = self.post_layernorm(x)

        # Group patches into spatial_merge_size x spatial_merge_size windows
        # and move the channel dim first for the Conv2d downsampler.
        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size,
                   x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        x = self.merger(x)

        return x

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this vision tower.

        A checkpoint name that contains a separate q/k/v or gate/up
        projection is rewritten to the fused parameter. The tensor is then
        handed to that parameter's ``weight_loader`` with the matching shard
        id. Every other tensor is loaded directly, using
        ``default_weight_loader`` when the parameter has no custom loader.

        Returns:
            The set of (possibly rewritten) parameter names that were loaded.
        """
        # (fused parameter fragment, checkpoint fragment, shard id)
        fused_shards = (
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        )
        params_by_name = dict(self.named_parameters(remove_duplicate=False))
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            for fused_frag, ckpt_frag, shard in fused_shards:
                if ckpt_frag in ckpt_name:
                    ckpt_name = ckpt_name.replace(ckpt_frag, fused_frag)
                    target = params_by_name[ckpt_name]
                    target.weight_loader(target, tensor, shard)
                    break
            else:
                target = params_by_name[ckpt_name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, tensor)
            loaded.add(ckpt_name)
        return loaded


class Glm4vProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and token-count helpers for GLM-4.1V."""

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_tokenizer(self):
        return self.ctx.tokenizer

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed image limit (None); at most one video per prompt.
        return {"image": None, "video": 1}

    def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 16,
        do_resize: bool = True,
        max_image_pixels: int = 28 * 28 * 2 * 30000,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames in the input.
            do_resize: Whether to run ``smart_resize`` before counting.
            max_image_pixels: Pixel budget forwarded to ``smart_resize``.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size
        if do_resize:
            resized_height, resized_width = smart_resize(
                num_frames=num_frames
                if num_frames > temporal_patch_size else temporal_patch_size,
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                max_pixels=max_image_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        # NOTE: Frames are padded up to a multiple of `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % p` is the distance to the next multiple of p; the previous
        # `n % p` form was only correct for p == 2.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(image_width=9999999,
                                                  image_height=9999999)
        return max_image_size

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # NOTE(review): `num_frames` is left at the `_get_vision_info`
        # default (16) for still images; confirm this matches the HF image
        # processor, especially for images below the pixel budget.
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            max_image_pixels=28 * 28 * 2 * 6144,
        )
        return num_image_tokens

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            max_image_pixels=28 * 28 * 2 * 30000,
        )
        return num_video_tokens

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Largest frame count whose video tokens still fit in max_tokens."""
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )
            if next_max_tokens > max_tokens or next_max_tokens == 0:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        # Spend whatever the images leave of seq_len on video frames, split
        # evenly across videos and capped per video.
        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_frames_per_video, 1)

    def _get_video_second_idx(self, metadata: dict[str, Any],
                              total_frames: int) -> list[int]:
        """Return the whole-second timestamp of every other kept frame.

        Frame indices come either from ``metadata["frames_indices"]`` (when
        frames were not re-sampled) or from a sampling rule derived from the
        video processor's ``fps`` / ``max_duration``. Duplicates are dropped
        (order preserved), the list is padded to an even length, and the
        timestamp of every second frame is returned. Presumably each
        returned timestamp labels one pair of frames; confirm against
        ``temporal_patch_size``.
        """
        video_processor = self.get_video_processor()

        video_fps = metadata.get("fps", video_processor.fps)
        meta_frames = metadata.get("total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = metadata.get("duration",
                                round(max_frame_idx / video_fps) + 1)
        # Same default as `_call_hf_processor`; previously a missing key
        # raised KeyError here.
        do_sample_frames = metadata.get("do_sample_frames", True)
        if not do_sample_frames:
            frame_indices = metadata["frames_indices"]
        elif duration <= video_processor.max_duration:
            n = int(math.floor(duration * video_processor.fps))
            frame_indices = [
                min(
                    max_frame_idx,
                    int(math.ceil(i * video_fps / video_processor.fps)),
                ) for i in range(n)
            ]
        else:
            num_samples = int(video_processor.max_duration *
                              video_processor.fps)
            if num_samples >= meta_frames:
                frame_indices = list(range(meta_frames))
            else:
                target_seconds = np.linspace(0,
                                             duration,
                                             num_samples,
                                             endpoint=True)
                frame_indices = [
                    min(max_frame_idx, int(math.ceil(t * video_fps)))
                    for t in target_seconds
                ]

        # Order-preserving de-duplication, then pad to an even count.
        frame_indices = list(dict.fromkeys(frame_indices))
        if len(frame_indices) % 2:
            frame_indices.append(frame_indices[-1])

        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
        return full_second_idxs[::2]

    def _construct_video_placeholder(
        self,
        video_array: np.ndarray,
        metadata: dict[str, Any],
        grid_thw: torch.Tensor,
    ) -> list[int]:
        """Build the token ids that stand in for one video in the prompt.

        Layout: the video-start token, then for each timestamp the
        image-start token, ``h * w / merge_size**2`` video tokens, the
        image-end token and the tokenized timestamp, and finally the
        video-end token.
        """
        hf_processor = self.get_hf_processor()
        tokenizer = self.get_tokenizer()
        image_processor = hf_processor.image_processor

        hf_config = self.get_hf_config()
        boi_token_id = hf_config.image_start_token_id
        eoi_token_id = hf_config.image_end_token_id
        bov_token_id = hf_config.video_start_token_id
        eov_token_id = hf_config.video_end_token_id
        merge_length = image_processor.merge_size**2

        assert isinstance(grid_thw, torch.Tensor)
        timestamps = self._get_video_second_idx(metadata, len(video_array))
        frames_idx_token = [
            tokenizer.encode(str(i), add_special_tokens=False)
            for i in timestamps
        ]
        _, H, W = grid_thw
        num_tokens_per_frame = int(H * W) // merge_length
        placeholder = [bov_token_id]
        for frame_idx in frames_idx_token:
            placeholder.append(boi_token_id)
            placeholder.extend([hf_processor.video_token_id] *
                               num_tokens_per_frame)
            placeholder.append(eoi_token_id)
            placeholder.extend(frame_idx)
        placeholder.append(eov_token_id)

        return placeholder


class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
    """Builds dummy prompts and media sized to the largest supported input."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_config = self.info.get_hf_config()
        hf_processor = self.info.get_hf_processor()
        tokenizer = self.info.get_tokenizer()

        image_token: str = hf_processor.image_token
        # A video is marked by <video_start><video_token><video_end>.
        video_token_ids = [
            hf_config.video_start_token_id,
            hf_processor.video_token_id,
            hf_config.video_end_token_id,
        ]
        video_token = tokenizer.decode(video_token_ids)

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = (
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)
        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """Return ``num_videos`` white videos paired with synthetic metadata.

        The metadata describes a 2 fps clip whose frames are all kept
        (``do_sample_frames`` is False).
        """
        # Frames use the usual (num_frames, height, width, channels) layout;
        # the previous (width, height) order was transposed.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))

        return video_items


class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
    """Runs the HF processor and expands GLM-4.1V image/video placeholders."""

    # Prompt text marking the position of one video.
    _VIDEO_PLACEHOLDER = "<|begin_of_video|><|video|><|end_of_video|>"

    def _get_data_parser(self) -> MultiModalDataParser:
        # Video items must carry metadata; it is used below to build the
        # per-frame timestamp tokens.
        return MultiModalDataParser(video_needs_metadata=True)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Process videos one at a time, then the prompt and images.

        Each video is run through the HF processor on its own so its
        placeholder text can be rewritten. That text is then substituted
        for the first remaining video placeholder in ``prompt``. The
        per-video pixel values and grids are concatenated and merged with
        the outputs for the rest of the prompt.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)

        # GLM-4.1V use `image_token_id` as video placeholder, we need to
        # replace it with `video_token_id` for video processing. So we
        # separate video processing from image processing.
        if ("videos" in mm_data and isinstance(mm_data["videos"], list)
                and len(mm_data["videos"]) > 0):
            video_grid_thw_lst = []
            pixel_values_videos_lst = []
            for item in mm_data.pop("videos", []):
                video_array, metadata = item

                # don't update mm_kwargs inplace
                video_mm_kwargs = dict(**mm_kwargs)
                video_mm_kwargs["do_sample_frames"] = metadata.get(
                    "do_sample_frames", True)

                video_mm_data = dict()
                video_mm_data["videos"] = [[video_array]]

                # backward compatibility for Transformers 4.55
                unuse_metadata = ["do_sample_frames"]
                if not hasattr(
                        VideoMetadata,
                        "frames_indices") and "frames_indices" in metadata:
                    unuse_metadata.append("frames_indices")

                video_mm_data["video_metadata"] = [[
                    VideoMetadata(
                        **{
                            k: metadata[k]
                            for k in metadata if k not in unuse_metadata
                        })
                ]]

                video_outputs = super()._call_hf_processor(
                    prompt=self._VIDEO_PLACEHOLDER,
                    mm_data=video_mm_data,
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )
                if not video_mm_kwargs["do_sample_frames"] and Version(
                        TRANSFORMERS_VERSION) < Version("4.56.0"):
                    # Transformers v4.55 has incorrect timestamps issue for
                    # skip sampling. We construct the placeholder manually to
                    # get placeholders with correct timestamps.
                    placeholder = self.info._construct_video_placeholder(
                        video_array,
                        metadata,
                        video_outputs["video_grid_thw"].squeeze(0),
                    )
                    video_placeholder = processor.tokenizer.decode(placeholder)
                else:
                    input_ids = video_outputs.pop("input_ids")
                    input_ids[input_ids == processor.image_token_id] = (
                        processor.video_token_id)
                    video_placeholder = processor.tokenizer.batch_decode(
                        input_ids)[0]
                # Only the first remaining placeholder belongs to this video.
                prompt = prompt.replace(
                    self._VIDEO_PLACEHOLDER,
                    video_placeholder,
                    1,
                )

                video_grid_thw_lst.append(video_outputs["video_grid_thw"])
                pixel_values_videos_lst.append(
                    video_outputs["pixel_values_videos"])
            video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_grid_thw=torch.cat(video_grid_thw_lst),
            )
        else:
            video_outputs = dict()

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        combined_outputs = dict(
            processed_outputs,
            **video_outputs,
        )
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _create_qwen2vl_field_factory(
            self.info.get_hf_config().vision_config.spatial_merge_size)(
                hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)

        merge_length = image_processor.merge_size**2

        def get_image_replacement_glm4v(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            grid_thw = out_item["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            # One image token per merged patch.
            num_tokens = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * num_tokens

        def get_video_replacement_glm4v(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            grid_thw = out_item["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)

            video, metadata = mm_items["video"][item_idx]
            placeholder = self.info._construct_video_placeholder(
                video, metadata, grid_thw)
            # Only the video tokens receive embeddings; the structural and
            # timestamp tokens stay as text.
            return PromptUpdateDetails.select_token_id(
                placeholder,
                embed_token_id=hf_processor.video_token_id,
            )

        return [
            PromptReplacement(
                modality="image",
                target=hf_processor.image_token,
                replacement=get_image_replacement_glm4v,
            ),
            PromptReplacement(
                modality="video",
                target=self._VIDEO_PLACEHOLDER,
                replacement=get_video_replacement_glm4v,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                    SupportsLoRA, SupportsPP):
    # Fused parameter name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_up_proj"],
    }

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
        })

    # NOTE(review): opt-in flag; presumably enables the "data"
    # mm_encoder_tp_mode that __init__ reads — confirm.
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder text for an image or video item.

        Raises:
            ValueError: If ``modality`` starts with neither "image" nor
                "video".
        """
        placeholders = (
            ("image", "<|begin_of_image|><|image|><|end_of_image|>"),
            ("video", "<|begin_of_video|><|video|><|end_of_video|>"),
        )
        for prefix, placeholder in placeholders:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        # When True, vision inputs are routed through
        # `run_dp_sharded_mrope_vision_model` in the _process_*_input methods.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Glm4vVisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )

        # Choose the text backbone for the dense vs. MoE checkpoint flavour;
        # None leaves the choice to the registry.
        if config.model_type == "glm4v":
            architectures = ["Glm4ForCausalLM"]
        elif config.model_type == "glm4v_moe":
            architectures = ["Glm4MoeForCausalLM"]
        else:
            architectures = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=architectures)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
1399
            return mm_input.reshape(-1, mm_input.shape[-1])
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Glm4vImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Glm4vImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            return Glm4vImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Glm4vVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
1441

1442
1443
        if pixel_values_videos is None and video_embeds is None:
            return None
1444

1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Glm4vVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return Glm4vVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _process_image_input(
            self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
1473
        grid_thw_list = grid_thw.tolist()
1474
1475
1476
1477
1478

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
1479
1480
1481
1482
1483
1484
1485
1486
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values,
                                                         grid_thw.tolist(),
                                                         rope_type="rope_3d")
            else:
                image_embeds = self.visual(pixel_values,
                                           grid_thw=grid_thw.tolist())
1487
        merge_size = self.visual.spatial_merge_size
1488
1489
1490
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()
        return image_embeds.split(sizes)
1491
1492
1493
1494
1495

    def _process_video_input(
            self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]:
        """Encode a batch of videos and split the result per video.

        Args:
            video_input: Parsed video inputs. ``video_input["type"]`` selects
                between precomputed ``"video_embeds"`` and raw
                ``"pixel_values_videos"`` that are run through
                ``self.visual``. ``video_input["video_grid_thw"]`` must be
                2-D (asserted); the product of each row gives that video's
                patch count.

        Returns:
            One embedding tensor per video. On the data-parallel path the
            value from ``run_dp_sharded_mrope_vision_model`` is returned
            unchanged (presumably already split per item -- TODO confirm).
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        # Convert once; reused for the vision tower and for the split sizes.
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(self.visual,
                                                         pixel_values_videos,
                                                         grid_thw_list,
                                                         rope_type="rope_3d")
            video_embeds = self.visual(pixel_values_videos,
                                       grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item. Per-item count
        # is prod(t, h, w) divided by merge_size**2 (spatial merge).
        merge_size = self.visual.spatial_merge_size
        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
                 (merge_size * merge_size)).tolist()
        return video_embeds.split(sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "image" not in mm_input_by_modality):
                mm_input_by_modality["image"] = (
                    self._parse_and_validate_image_input(**kwargs))
            if (input_key in ("pixel_values_videos", "video_embeds")
                    and "video" not in mm_input_by_modality):
                mm_input_by_modality["video"] = (
                    self._parse_and_validate_video_input(**kwargs))
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        """Return the ``language_model`` submodule of this model."""
        text_model = self.language_model
        return text_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute per-item embeddings for the multimodal inputs in kwargs.

        Returns:
            ``None`` when ``_parse_and_validate_multimodal_inputs`` finds no
            image/video inputs. Otherwise returns a tuple of tensors, one per
            multimodal data item (image or video), in the modality order
            produced by that helper.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            **kwargs)
        if not mm_input_by_modality:
            return None

        # The resulting multimodal_embeddings is a tuple of tensors, with
        # each tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # in insertion order to preserve the order of the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            elif modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and, when possible, merge multimodal embeddings.

        Args:
            input_ids: Token ids passed to the language model's embedding
                lookup.
            multimodal_embeddings: Per-item image/video embeddings. Merging
                is skipped when this is ``None``, empty, or any item has
                zero elements.

        Returns:
            The token embeddings. When merging happens, they are combined via
            ``merge_multimodal_embeddings`` using the image and video token
            ids from ``self.config``.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)

        skip_merge = (multimodal_embeddings is None
                      or len(multimodal_embeddings) == 0
                      or any(item.numel() == 0
                             for item in multimodal_embeddings))
        if skip_merge:
            return text_embeds

        return merge_multimodal_embeddings(
            input_ids,
            text_embeds,
            multimodal_embeddings,
            [self.config.image_token_id, self.config.video_token_id],
        )

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[Glm4vImageInputs] = None,
        video_input: Optional[Glm4vVideoInputs] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge image/video embeddings (V0 path).

        Images are processed and merged first, then videos. Each modality is
        merged with ``merge_multimodal_embeddings`` using its own
        placeholder token id from ``self.config``. Absent (``None``) inputs
        are skipped.
        """
        merged = self.get_input_embeddings(input_ids)

        modality_steps = (
            (image_input, self._process_image_input, "image_token_id"),
            (video_input, self._process_video_input, "video_token_id"),
        )
        for mm_input, encode, token_id_attr in modality_steps:
            if mm_input is None:
                continue
            merged = merge_multimodal_embeddings(
                input_ids,
                merged,
                encode(mm_input),
                placeholder_token_id=getattr(self.config, token_id_attr),
            )
        return merged

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for GLM-4V.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for GLM-4V
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Optional intermediate tensors for pipeline
                parallelism. When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional pre-computed input embeddings.
            **kwargs: Additional keyword arguments. When neither
                ``intermediate_tensors`` nor ``inputs_embeds`` is given, they
                are parsed for image/video inputs (V0 compatibility path).

        Returns:
            The output of ``self.language_model.model``: a tensor or
            ``IntermediateTensors`` (per the return annotation; the latter
            presumably on non-final pipeline stages).

        Raises:
            AssertionError: If multimodal inputs are present, mrope is in use,
                and ``positions`` is not shaped ``(3, seq_len)``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner
            # from `get_multimodal_embeddings` and `get_input_embeddings`;
            # this branch is only for v0 compatibility.
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)

            # With no multimodal input, inputs_embeds stays None and the
            # language model embeds input_ids itself.
            if image_input is not None or video_input is not None:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input)
                # Embeddings already include the text tokens.
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the wrapped language model.

        Args:
            hidden_states: Hidden states to turn into logits (presumably the
                output of ``forward``).

        Returns:
            Whatever ``self.language_model.compute_logits`` returns.
        """
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model via ``AutoWeightsLoader``.

        ``self.hf_to_vllm_mapper`` is passed to the loader as the name mapper.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``
            (presumably the names of loaded parameters -- TODO confirm).
        """
        weights_loader = AutoWeightsLoader(self)
        loaded_names = weights_loader.load_weights(
            weights, mapper=self.hf_to_vllm_mapper)
        return loaded_names

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefixes in this multimodal model.

        Returns:
            A ``MultiModelKeys`` naming the language model, the connector
            (``visual.merger``) and the vision tower (``visual``) prefixes.
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.model",
            connector="visual.merger.",
            tower_model="visual.",
        )


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }