"vllm/vscode:/vscode.git/clone" did not exist on "10383887e03412196a2689b9398290719c4797bf"
glm4_1v.py 60.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The ZhipuAI Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""

import math
30
from collections.abc import Callable, Iterable, Mapping, Sequence
31
from functools import partial
32
from typing import Annotated, Any, Literal, TypeAlias
33
34
35
36
37
38

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
39
from packaging.version import Version
40
from transformers import BatchFeature
41
from transformers import __version__ as TRANSFORMERS_VERSION
Yuxuan Zhang's avatar
Yuxuan Zhang committed
42
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
43
from transformers.models.glm4v.image_processing_glm4v import (
44
45
46
47
    Glm4vImageProcessor,
    smart_resize,
)
from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
48
49
from transformers.video_utils import VideoMetadata

50
from vllm.attention.backends.registry import _Backend
51
52
53
54
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
55
from vllm.config import VllmConfig
56
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
57
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
58
59
60
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
61
62
63
64
65
66
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
67
68
69
70
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
85
86
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
87
from vllm.utils.tensor_schema import TensorSchema, TensorShape
88
89

from ..layers.activation import SiluAndMul
90
91
92
93
94
95
96
97
98
99
100
101
102
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
103
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
104
105
106
107
108
109
110
111
112

logger = init_logger(__name__)

# Per-video frame cap used for the profiling run.
_MAX_FRAMES_PER_VIDEO = 600

# === Vision Inputs === #


113
class Glm4vImagePixelInputs(TensorSchema):
    """
    Image inputs given as raw (flattened) pixel patches.

    Dimensions:
        - np: Number of patches
        - cpp: Number of channels * patch_size * patch_size
        - ni: Number of images
        - The fixed trailing axis of size 3 in `image_grid_thw` holds
          (grid_t, grid_h, grid_w) for each image
    """

    type: Literal["pixel_values"] = "pixel_values"

    # One row of flattened pixel values per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")]
    # One (t, h, w) patch-grid row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
126
127


128
class Glm4vImageEmbeddingInputs(TensorSchema):
    """
    Image inputs given as embeddings (``type == "image_embeds"``).

    Dimensions:
        - f: Number of image features (varies based on image resolution)
        - h: Hidden size (must match language model backbone)
        - n: Number of images
        - The fixed trailing axis of size 3 in `image_grid_thw` holds
          (grid_t, grid_h, grid_w) for each image
    """

    type: Literal["image_embeds"] = "image_embeds"

    # One embedding row per image feature.
    image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")]
    # One (t, h, w) patch-grid row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]


# Union of the two accepted image-input schemas.
Glm4vImageInputs: TypeAlias = Glm4vImagePixelInputs | Glm4vImageEmbeddingInputs
144
145


146
class Glm4vVideoPixelInputs(TensorSchema):
    """
    Video inputs given as raw (flattened) pixel patches.

    Dimensions:
        - np: Number of patches
        - ctpp: Number of channels * temporal_patch_size *
            patch_size * patch_size
        - f: Number of frames
        - The fixed trailing axis of size 3 in `video_grid_thw` holds
          (grid_t, grid_h, grid_w); grid_t is usually 1 for processed video
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"

    # One row of flattened pixel values per patch.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")]
    # One (t, h, w) patch-grid row per frame.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
161
162


163
class Glm4vVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs given as embeddings (``type == "video_embeds"``).

    Dimensions:
        - p: Number of video patches across all frames
        - h: Hidden size (must match language model backbone)
        - f: Number of frames
        - The fixed trailing axis of size 3 in `video_grid_thw` holds
          (grid_t, grid_h, grid_w); grid_t is usually 1 for processed video
    """

    type: Literal["video_embeds"] = "video_embeds"

    # One embedding row per video patch.
    video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")]
    # One (t, h, w) patch-grid row per frame.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]


# Union of the two accepted video-input schemas.
Glm4vVideoInputs: TypeAlias = Glm4vVideoPixelInputs | Glm4vVideoEmbeddingInputs
180

181
# ==== Vision Encoder ==== #
182
183
184
185
186
187
188
189


class Glm4vVisionMLP(nn.Module):
    """Gated feed-forward block of a vision transformer layer.

    ``gate_up_proj`` produces the gate and up halves in one fused matmul
    (two outputs of ``hidden_features``). They are combined by ``SiluAndMul``
    (presumably ``silu(gate) * up``). ``down_proj`` then maps back to
    ``in_features``. With ``use_data_parallel`` both linears are built with
    tensor parallelism disabled.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=in_features,
            output_sizes=[hidden_features, hidden_features],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            input_size=hidden_features,
            output_size=in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor):
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` over the TP group and interleave the pieces.

    Every gathered tensor is split along its last dim into chunks of
    ``hidden_size // tp_size``. The result concatenates chunk 0 of every rank,
    then chunk 1 of every rank, and so on.
    """
    import torch.distributed as dist

    buffers = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        buffers,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )

    chunk_width = hidden_size // tp_size
    per_rank_chunks = [torch.split(buf, chunk_width, -1) for buf in buffers]
    # zip(*...) yields the i-th chunk from every rank together.
    interleaved = [
        piece for same_index in zip(*per_rank_chunks) for piece in same_index
    ]
    return torch.cat(interleaved, dim=-1)


class Glm4vVisionAttention(nn.Module):
    """Self-attention over packed vision patch sequences.

    Several images/frames are concatenated along the sequence axis.
    ``cu_seqlens`` (or ``seqlens`` for xFormers) keeps attention inside each
    segment. Supported kernels: FlashAttention (including ROCm AITER),
    PyTorch SDPA (one segment at a time) and xFormers (block-diagonal mask).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        # Data-parallel mode builds the linears with disable_tp, so this layer
        # behaves as a single tensor-parallel rank.
        if use_data_parallel:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = parallel_state.get_tensor_model_parallel_rank()

        head_dim = dist_utils.divide(projection_size, num_heads)
        self.hidden_size_per_attention_head = head_dim
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # Under quantization the prefix is "qkv_proj" so that it lines up with
        # the GLM-4.5V-FP8 quantization config.
        qkv_prefix = f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv"
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=head_dim,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=qkv_prefix,
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            bias=False,
            disable_tp=use_data_parallel,
        )

        # Select a backend, then let maybe_get_vit_flash_attn_backend resolve
        # the final backend together with the varlen flash function.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        supported_backends = {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }
        if self.attn_backend not in supported_backends:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now."
            )

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused ``[s, b, 3 * heads * head_dim]`` tensor into q, k, v.

        Each part is viewed as ``[s, b, heads, head_dim]``, where ``heads`` is
        the per-partition head count.
        """
        seq_len, bs, _ = qkv.shape
        per_head_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        return tuple(part.view(*per_head_shape) for part in qkv.chunk(3, dim=2))

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Attend within each packed segment; input and output are ``[s, b, c]``."""
        fused, _ = self.qkv(x)
        q, k, v = self.split_qkv(fused)
        batch_size = q.shape[1]

        # [s, b, heads, head_dim] -> [b, s, heads, head_dim]
        q, k, v = (
            rearrange(t, "s b ... -> b s ...").contiguous() for t in (q, k, v)
        )
        if rotary_pos_emb is not None:
            # Rotate q and k with a single call by stacking them on dim 0.
            rotated = apply_rotary_pos_emb_vision(
                torch.cat([q, k], dim=0), rotary_pos_emb
            )
            q, k = torch.chunk(rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            context_layer = self._flash_attention(
                q, k, v, cu_seqlens, max_seqlen, batch_size
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            context_layer = self._sdpa_attention(q, k, v, cu_seqlens)
        elif self.attn_backend == _Backend.XFORMERS:
            context_layer = self._xformers_attention(q, k, v, seqlens)

        output, _ = self.proj(context_layer)
        return output

    def _flash_attention(self, q, k, v, cu_seqlens, max_seqlen, batch_size):
        """Varlen FlashAttention over the packed batch; returns ``[s, b, h*d]``."""
        q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in [q, k, v])
        out = self.flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            dropout_p=0,
            causal=False,
        )
        return rearrange(out, "(b s) h d -> s b (h d)", b=batch_size).contiguous()

    @staticmethod
    def _sdpa_attention(q, k, v, cu_seqlens):
        """Run SDPA one segment at a time (faster and lighter on VRAM)."""
        segment_outputs = []
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            q_i, k_i, v_i = (
                rearrange(t[:, start:end], "b s h d -> b h s d") for t in (q, k, v)
            )
            out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
            segment_outputs.append(rearrange(out_i, "b h s d -> b s h d "))
        merged = torch.cat(segment_outputs, dim=1)
        return rearrange(merged, "b s h d -> s b (h d)").contiguous()

    @staticmethod
    def _xformers_attention(q, k, v, seqlens):
        """xFormers attention with a block-diagonal mask built from ``seqlens``."""
        from xformers import ops as xops
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask

        attn_bias = BlockDiagonalMask.from_seqlens(
            q_seqlen=seqlens, kv_seqlen=None, device=q.device
        )
        out = xops.memory_efficient_attention_forward(
            q, k, v, attn_bias=attn_bias, p=0, scale=None
        )
        return rearrange(out, "b s h d -> s b (h d)").contiguous()


class Glm4vVisionBlock(nn.Module):
    """One pre-norm vision transformer layer: attention, then gated MLP.

    The second residual add is fused into ``norm2``, which is called with
    ``residual=`` and returns ``(normalized, residual)``.
    NOTE(review): the ``nn.LayerNorm`` fallback used when ``norm_layer`` is
    None does not accept ``residual=``. Callers in this file pass RMSNorm.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()
        make_norm = (
            norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        )
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)
        self.attn = Glm4vVisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            mlp_hidden_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        mlp_input, residual = self.norm2(x, residual=attn_out)
        return residual + self.mlp(mlp_input)


class Glm4vVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a Conv3d whose kernel equals its stride.

    Each input row is reshaped to ``(channels, temporal_patch_size,
    patch_size, patch_size)``. Because the kernel covers the whole patch,
    every row maps to exactly one ``hidden_size`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 1,
        in_channels: int = 3,
        hidden_size: int = 1536,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        patch_kernel = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking also rejects inputs that are not 2-D.
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(patches).view(num_patches, self.hidden_size)


class Glm4vPatchMerger(nn.Module):
    """Final projection head of the vision tower.

    forward: ``proj`` (d_model -> d_model) -> LayerNorm -> GELU -> fused
    ``gate_up_proj`` (two ``context_dim`` outputs) -> ``SiluAndMul`` ->
    ``down_proj`` (context_dim -> d_model).
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = d_model
        width = self.hidden_size

        self.proj = ColumnParallelLinear(
            width,
            width,
            bias=bias,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.post_projection_norm = nn.LayerNorm(width)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=width,
            output_sizes=[context_dim, context_dim],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            context_dim,
            width,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()
        self.extra_activation_func = nn.GELU()

    def forward(self, x: torch.Tensor):
        projected, _ = self.proj(x)
        activated = self.extra_activation_func(self.post_projection_norm(projected))
        gate_up, _ = self.gate_up_proj(activated)
        merged, _ = self.down_proj(self.act_fn(gate_up))
        return merged


class Glm4vVisionEmbeddings(nn.Module):
    """Learned 2D position embeddings, resampled to every image's patch grid.

    The table has ``(image_size // patch_size) ** 2`` rows and is treated as a
    square grid. In ``forward`` it is bicubically sampled (``F.grid_sample``)
    at each patch's (h, w) coordinate, normalized by the patch-grid height and
    width of the image/frame the patch belongs to. The result is added to the
    patch embeddings.
    """

    def __init__(self, config: Glm4vVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Not read by forward(); presumably kept for API/checkpoint parity.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self, embeddings, lengths, image_shapes, h_coords, w_coords
    ) -> torch.Tensor:
        """Return ``embeddings`` plus the interpolated position embeddings.

        Args:
            embeddings: Patch embeddings, one row per patch (presumably
                ``[total_seq, hidden]``; the position term is broadcast-added).
            lengths: Patch count of each attention sequence (list or tensor).
                ``Glm4vVisionTransformer.forward`` passes one entry per frame,
                i.e. each item of ``image_shapes`` contributes ``grid_t``
                entries.
            image_shapes: ``(grid_t, grid_h, grid_w)`` rows per item (list or
                tensor).
            h_coords, w_coords: Per-patch row / column indices, length
                ``total_seq``.
        """
        pos_embed_weight = self.position_embedding.weight
        hidden_size = pos_embed_weight.shape[1]
        total_seq = h_coords.shape[0]
        device = pos_embed_weight.device

        h_coords, w_coords = h_coords.to(device), w_coords.to(device)

        if total_seq == 0:
            adapted_pos_embed = torch.empty(
                0, hidden_size, device=device, dtype=pos_embed_weight.dtype
            )
        else:
            target_h, target_w = self._per_patch_grid_size(
                lengths, image_shapes, device
            )

            # The table is a square orig_size x orig_size grid; reshape it to
            # [1, hidden, H, W] as grid_sample's input.
            orig_size = int(pos_embed_weight.shape[0] ** 0.5)
            pos_embed_2d = (
                pos_embed_weight.view(orig_size, orig_size, hidden_size)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device=device, dtype=torch.float32)
            )

            # Map patch centres into grid_sample's [-1, 1] coordinate range.
            h_coords = h_coords.to(device=device, dtype=torch.float32)
            w_coords = w_coords.to(device=device, dtype=torch.float32)
            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

            # grid_sample takes (x, y) = (w, h) pairs: grid is [1, total_seq, 1, 2].
            grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
            interpolated_embed_fp32 = F.grid_sample(
                pos_embed_2d,
                grid,
                mode="bicubic",
                align_corners=False,
                padding_mode="border",
            )

            # [1, hidden, total_seq, 1] -> [total_seq, hidden], original dtype.
            adapted_pos_embed = (
                interpolated_embed_fp32.squeeze(0)
                .squeeze(-1)
                .permute(1, 0)
                .to(pos_embed_weight.dtype)
                .to(embeddings.device)
            )

        return embeddings + adapted_pos_embed

    @staticmethod
    def _per_patch_grid_size(lengths, image_shapes, device):
        """Return float32 ``(target_h, target_w)`` with one entry per patch.

        Sequence ``i`` contributes ``lengths[i]`` copies of the ``grid_h`` /
        ``grid_w`` of the item it belongs to.
        """
        lengths = torch.as_tensor(lengths, device=device, dtype=torch.long)
        shapes = torch.as_tensor(image_shapes, device=device, dtype=torch.long)
        num_seqs = lengths.shape[0]

        if num_seqs != shapes.shape[0]:
            # Items with grid_t > 1 yield one sequence per frame. Expand the
            # shape rows by grid_t so every frame keeps its own item's h/w.
            per_frame = shapes.repeat_interleave(shapes[:, 0], dim=0)
            if per_frame.shape[0] == num_seqs:
                shapes = per_frame

        if num_seqs > shapes.shape[0]:
            # Still fewer shape rows than sequences (e.g. only partial shapes
            # available under data parallelism): cycle through the rows.
            cycle_idx = torch.arange(num_seqs, device=device) % shapes.shape[0]
            shapes = shapes[cycle_idx]
        else:
            shapes = shapes[:num_seqs]

        target_h = shapes[:, 1].repeat_interleave(lengths).to(torch.float32)
        target_w = shapes[:, 2].repeat_interleave(lengths).to(torch.float32)
        return target_h, target_w


class Glm4vVisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary angles for vision positions.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)``: one row
    per position and ``len(inv_freq)`` (= ``ceil(dim / 2)``) columns.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Ensure at least ``seqlen`` positions are cached.

        On a miss the table is rebuilt for ``2 * seqlen`` positions, so that
        slightly larger requests do not trigger another rebuild. ``seqlen`` may
        be a 0-d tensor (``rot_pos_emb`` passes ``grid_thw[:, 1:].max()``) and
        is never modified in place.
        """
        if seqlen <= self._seq_len_cached:
            return
        # Deliberately out-of-place: ``seqlen *= 2`` would double a caller's
        # 0-d tensor in place, and forward() would then slice twice as many
        # rows as requested.
        capacity = seqlen * 2
        self._seq_len_cached = capacity
        # Recompute inv_freq in float32 on its current device.
        # NOTE(review): presumably to undo precision loss if the module was
        # cast to a lower-precision dtype -- confirm.
        self.inv_freq = 1.0 / (
            self.theta
            ** (
                torch.arange(
                    0,
                    self.dim,
                    2,
                    dtype=torch.float,
                    device=self.inv_freq.device,
                )
                / self.dim
            )
        )
        positions = torch.arange(
            capacity, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(positions, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the angle table for the first ``seqlen`` positions."""
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Glm4vVisionTransformer(nn.Module):
    """GLM-4V vision encoder.

    ``forward`` pipeline:
      1. Conv3d patch embedding, then RMSNorm.
      2. Interpolated 2D position embeddings.
      3. ``depth`` transformer blocks with 2D rotary positions and one
         attention segment per frame.
      4. RMSNorm, then Conv2d downsampling over ``spatial_merge_size`` x
         ``spatial_merge_size`` windows.
      5. Patch merger.

    Each merged patch yields ``out_hidden_size`` features.
    """

    def __init__(
        self,
        vision_config: Glm4vVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ) -> None:
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel

        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.out_hidden_size = vision_config.out_hidden_size

        self.patch_embed = Glm4vVisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(RMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.out_hidden_size,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    use_data_parallel=self.use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for layer_idx in range(depth)
            ]
        )
        self.merger = Glm4vPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=vision_config.intermediate_size,
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.merger",
            use_data_parallel=self.use_data_parallel,
        )
        self.embeddings = Glm4vVisionEmbeddings(vision_config)

        self.post_conv_layernorm = RMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )
        self.downsample = nn.Conv2d(
            in_channels=vision_config.hidden_size,
            out_channels=vision_config.out_hidden_size,
            kernel_size=vision_config.spatial_merge_size,
            stride=vision_config.spatial_merge_size,
        )
        self.post_layernorm = RMSNorm(
            vision_config.hidden_size, eps=vision_config.rms_norm_eps
        )

        # In this class the backend only decides whether
        # compute_attn_mask_seqlen produces max_seqlen. Each attention layer
        # resolves its own backend.
        # NOTE(review): confirm both resolutions agree, so max_seqlen is set
        # whenever a layer ends up on FlashAttention.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(
        self, grid_thw: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build rotary embeddings and (h, w) position ids for every patch.

        Ids are reordered so that each ``spatial_merge_size`` x
        ``spatial_merge_size`` window is contiguous. This matches the
        ``view(-1, m, m, hidden)`` before downsampling in ``forward``. Each
        item's ids are repeated ``t`` times.

        Returns:
            ``(rotary_pos_emb, pos_ids)``, where ``pos_ids`` is
            ``[total_patches, 2]`` holding ``(h, w)`` per patch.
        """
        merge = self.spatial_merge_size

        def window_major(ids: torch.Tensor, h, w) -> torch.Tensor:
            # [h, w] grid -> flattened window by window.
            return (
                ids.reshape(h // merge, merge, w // merge, merge)
                .permute(0, 2, 1, 3)
                .flatten()
            )

        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            pos_ids.append(
                torch.stack(
                    [window_major(hpos_ids, h, w), window_major(wpos_ids, h, w)],
                    dim=-1,
                ).repeat(t, 1)
            )
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb, pos_ids

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> tuple[int | None, list[int] | None]:
        """Derive ``(max_seqlen, seqlens)`` from cumulative sequence offsets.

        ``seqlens`` is always returned: both the position embeddings and the
        xFormers path use it. ``max_seqlen`` is only computed for
        FlashAttention backends and is None otherwise.
        """
        segment_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        seqlens = segment_lengths.tolist()
        max_seqlen = None
        if self.attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
            max_seqlen = segment_lengths.max().item()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Encode flattened pixel patches ``x`` laid out per ``grid_thw`` rows."""
        # as_tensor: lists are converted; an existing tensor is not re-copied
        # (torch.tensor would warn and copy).
        grid_thw = torch.as_tensor(grid_thw, device=x.device, dtype=torch.long)

        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)
        x = self.post_conv_layernorm(x)

        rotary_pos_emb, pos_ids = self.rot_pos_emb(grid_thw)
        # One attention segment per frame: h * w patches, repeated t times.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # Pre-compute seqlens once to reduce cuMemcpy operations.
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        x = self.embeddings(x, seqlens, grid_thw, pos_ids[:, 0], pos_ids[:, 1])

        # transformer blocks, with x shaped [s, 1, hidden]
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        # adapter: merge each contiguous m x m window into one vector
        x = self.post_layernorm(x)
        x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x).view(-1, self.out_hidden_size)
        return self.merger(x)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, packing split q/k/v and gate/up weights.

        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Glm4vProcessingInfo(BaseProcessingInfo):
    """Processing metadata and vision-token accounting for GLM-4V inputs."""

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_tokenizer(self):
        return self.ctx.tokenizer

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Images are not capped here (None); at most one video per prompt.
        return {"image": None, "video": 1}

    def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
        return self.get_hf_processor(**kwargs).image_processor

    def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
        return self.get_hf_processor(**kwargs).video_processor

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 16,
        do_resize: bool = True,
        max_image_pixels: int = 28 * 28 * 2 * 30000,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an input.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames. NOTE(review): image callers rely on
                the default of 16 rather than a single frame; confirm this
                matches the HF image processor.
            do_resize: Whether to apply `smart_resize` before patching.
            max_image_pixels: Pixel budget passed to `smart_resize`.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size
        if do_resize:
            resized_height, resized_width = smart_resize(
                num_frames=max(num_frames, temporal_patch_size),
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                max_pixels=max_image_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width, height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width, height=image_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        # `(-n) % p` is the amount needed to round n up to a multiple of p;
        # the previous `n % p` form was only correct for p == 2.
        padded_num_frames = num_frames + (-num_frames) % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size of an arbitrarily large input."""
        max_image_size, _ = self._get_vision_info(
            image_width=9999999, image_height=9999999
        )
        return max_image_size

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of vision tokens produced for one image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            max_image_pixels=28 * 28 * 2 * 6144,
        )
        return num_image_tokens

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
    ) -> int:
        """Return the number of vision tokens produced for one video."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            max_image_pixels=28 * 28 * 2 * 30000,
        )
        return num_video_tokens

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose max-size video fits in
        `max_tokens` vision tokens (0 if even one frame does not fit)."""
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
            )
            if next_max_tokens > max_tokens or next_max_tokens == 0:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video that fit in `seq_len` after reserving the maximum
        image tokens, capped at `_MAX_FRAMES_PER_VIDEO` and at least 1."""
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
        )

        return max(max_frames_per_video, 1)

    def _get_video_second_idx(
        self, metadata: dict[str, Any], total_frames: int
    ) -> list[int]:
        """Return the whole-second timestamp of every other selected frame.

        Frame indices come from `metadata["frames_indices"]` when
        `metadata["do_sample_frames"]` is false. Otherwise they are sampled at
        the video processor's fps, or spread uniformly over the duration when
        the video exceeds `video_processor.max_duration`. Indices are then
        de-duplicated (first occurrence kept) and padded to an even count by
        repeating the last one, before every second one is converted to
        seconds.

        Args:
            metadata: Video metadata dict (`fps`, `total_num_frames`,
                `duration`, `do_sample_frames`, `frames_indices`).
            total_frames: Frame count used when metadata lacks
                `total_num_frames`.
        """
        video_processor = self.get_video_processor()

        video_fps = metadata.get("fps", video_processor.fps)
        meta_frames = metadata.get("total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1)

        if not metadata["do_sample_frames"]:
            frame_indices = metadata["frames_indices"]
        elif duration <= video_processor.max_duration:
            n = int(math.floor(duration * video_processor.fps))
            frame_indices = [
                min(
                    max_frame_idx,
                    int(math.ceil(i * video_fps / video_processor.fps)),
                )
                for i in range(n)
            ]
        else:
            num_samples = int(video_processor.max_duration * video_processor.fps)
            if num_samples >= meta_frames:
                frame_indices = list(range(meta_frames))
            else:
                target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
                frame_indices = [
                    min(max_frame_idx, int(math.ceil(t * video_fps)))
                    for t in target_seconds
                ]

        # Order-preserving de-duplication, then pad to an even length so the
        # indices can be consumed in pairs below.
        frame_indices = list(dict.fromkeys(frame_indices))
        if len(frame_indices) % 2:
            frame_indices.append(frame_indices[-1])

        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
        # One timestamp per pair of frames (presumably one per temporal patch
        # when temporal_patch_size == 2 -- TODO confirm).
        return full_second_idxs[::2]

    def _construct_video_placeholder(
        self,
        video_array: np.ndarray,
        metadata: dict[str, Any],
        grid_thw: torch.Tensor,
    ) -> list[int]:
        """Build the token-id placeholder sequence for one video.

        Layout: ``<video_start>`` followed, for each timestamp, by
        ``<image_start> video_token * N <image_end> <timestamp tokens>``,
        then ``<video_end>``, where ``N = h * w // merge_size**2``.

        Returns:
            The placeholder as a list of token ids.
        """
        hf_processor = self.get_hf_processor()
        tokenizer = self.get_tokenizer()
        image_processor = hf_processor.image_processor

        hf_config = self.get_hf_config()
        boi_token_id = hf_config.image_start_token_id
        eoi_token_id = hf_config.image_end_token_id
        bov_token_id = hf_config.video_start_token_id
        eov_token_id = hf_config.video_end_token_id
        merge_length = image_processor.merge_size**2

        assert isinstance(grid_thw, torch.Tensor)
        timestamps = self._get_video_second_idx(metadata, len(video_array))
        frames_idx_token = [
            tokenizer.encode(str(i), add_special_tokens=False) for i in timestamps
        ]
        _, grid_h, grid_w = grid_thw
        num_tokens_per_frame = int(grid_h * grid_w) // merge_length

        placeholder = [bov_token_id]
        for frame_idx in frames_idx_token:
            placeholder.append(boi_token_id)
            placeholder.extend([hf_processor.video_token_id] * num_tokens_per_frame)
            placeholder.append(eoi_token_id)
            placeholder.extend(frame_idx)
        placeholder.append(eov_token_id)

        return placeholder

class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
    """Build placeholder prompts and image/video data sized to the model's
    maximum feature counts."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image and one video marker per video."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        hf_config = self.info.get_hf_config()
        hf_processor = self.info.get_hf_processor()
        tokenizer = self.info.get_tokenizer()

        image_token: str = hf_processor.image_token
        video_token_ids = [
            hf_config.video_start_token_id,
            hf_processor.video_token_id,
            hf_config.video_end_token_id,
        ]
        video_token = tokenizer.decode(video_token_ids)

        return image_token * num_images + video_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the largest supported size."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts
        )

        image_overrides = mm_options.get("image") if mm_options else None
        video_overrides = mm_options.get("video") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
            "video": self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
                overrides=video_overrides,
            ),
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[VideoItem]:
        """Return `num_videos` white videos paired with fixed 2-fps metadata.

        Override values are clamped to the given maxima (a warning is logged
        when an override exceeds its maximum). The metadata disables frame
        sampling and lists every frame index.
        """
        if overrides:
            if overrides.num_frames:
                if overrides.num_frames > num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        overrides.num_frames,
                        num_frames,
                    )
                num_frames = min(num_frames, overrides.num_frames)
            if overrides.width:
                if overrides.width > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        overrides.width,
                        width,
                    )
                width = min(width, overrides.width)
            if overrides.height:
                if overrides.height > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        overrides.height,
                        height,
                    )
                height = min(height, overrides.height)

        # Laid out as (frames, height, width, channels) so the `width` and
        # `height` arguments are not transposed. NOTE(review): assumes the
        # same row-major frame layout produced by the video loaders; confirm.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            video_items.append((video.copy(), video_metadata))

        return video_items


class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
    """Multimodal processor that expands GLM-4V image and video placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # Videos must arrive with metadata: it drives the timestamps written
        # into the video placeholder.
        return MultiModalDataParser(video_needs_metadata=True)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling each video in a separate call.

        GLM-4.1V uses `image_token_id` as the video placeholder, so videos are
        processed one at a time: the per-video placeholder is rewritten to use
        `video_token_id` and substituted into `prompt`, after which the
        remaining inputs are processed in a single call and merged with the
        stacked video tensors.
        """
        mm_data = dict(mm_data)
        processor = self.info.get_hf_processor(**mm_kwargs)
        video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"

        video_outputs: dict[str, torch.Tensor] = {}
        videos = mm_data.get("videos")
        if isinstance(videos, list) and len(videos) > 0:
            mm_data.pop("videos")
            grid_thw_per_video = []
            pixels_per_video = []
            for video_array, metadata in videos:
                # Fresh dict so the caller's kwargs are never mutated.
                per_video_kwargs = {
                    **mm_kwargs,
                    "do_sample_frames": metadata.get("do_sample_frames", True),
                }

                # backward compatibility for Transformers 4.55: keys that its
                # VideoMetadata does not accept are dropped.
                dropped_keys = {"do_sample_frames"}
                if "frames_indices" in metadata and not hasattr(
                    VideoMetadata, "frames_indices"
                ):
                    dropped_keys.add("frames_indices")
                hf_metadata = VideoMetadata(
                    **{
                        key: value
                        for key, value in metadata.items()
                        if key not in dropped_keys
                    }
                )

                single_output = super()._call_hf_processor(
                    prompt=video_prompt,
                    mm_data={
                        "videos": [[video_array]],
                        "video_metadata": [[hf_metadata]],
                    },
                    mm_kwargs=per_video_kwargs,
                    tok_kwargs=tok_kwargs,
                )

                skips_sampling = not per_video_kwargs["do_sample_frames"]
                if skips_sampling and Version(TRANSFORMERS_VERSION) < Version(
                    "4.56.0"
                ):
                    # Transformers v4.55 produces wrong timestamps when frame
                    # sampling is skipped, so the placeholder is built here.
                    placeholder_ids = self.info._construct_video_placeholder(
                        video_array,
                        metadata,
                        single_output["video_grid_thw"].squeeze(0),
                    )
                    expanded_prompt = processor.tokenizer.decode(placeholder_ids)
                else:
                    token_ids = single_output.pop("input_ids")
                    is_image_token = token_ids == processor.image_token_id
                    token_ids[is_image_token] = processor.video_token_id
                    expanded_prompt = processor.tokenizer.batch_decode(token_ids)[0]

                prompt = prompt.replace(video_prompt, expanded_prompt, 1)
                grid_thw_per_video.append(single_output["video_grid_thw"])
                pixels_per_video.append(single_output["pixel_values_videos"])

            video_outputs = {
                "pixel_values_videos": torch.cat(pixels_per_video),
                "video_grid_thw": torch.cat(grid_thw_per_video),
            }

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        return BatchFeature({**processed_outputs, **video_outputs})

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        spatial_merge_size = self.info.get_hf_config().vision_config.spatial_merge_size
        field_factory = _create_qwen2vl_field_factory(spatial_merge_size)
        return field_factory(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/video marker with its expanded token sequence."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        merge_length = image_processor.merge_size**2

        def get_image_replacement_glm4v(item_idx: int):
            grid_thw = out_mm_kwargs["image"][item_idx]["image_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            token_count = int(grid_thw.prod()) // merge_length
            return [hf_processor.image_token_id] * token_count

        def get_video_replacement_glm4v(item_idx: int):
            grid_thw = out_mm_kwargs["video"][item_idx]["video_grid_thw"].data
            assert isinstance(grid_thw, torch.Tensor)
            video, metadata = mm_items["video"][item_idx]
            return PromptUpdateDetails.select_token_id(
                self.info._construct_video_placeholder(video, metadata, grid_thw),
                embed_token_id=hf_processor.video_token_id,
            )

        image_update = PromptReplacement(
            modality="image",
            target=hf_processor.image_token,
            replacement=get_image_replacement_glm4v,
        )
        video_update = PromptReplacement(
            modality="video",
            target="<|begin_of_video|><|video|><|end_of_video|>",
            replacement=get_video_replacement_glm4v,
        )
        return [image_update, video_update]


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
    """GLM-4V: a `Glm4vVisionTransformer` encoder plus a GLM-4 text model.

    The text backbone is created with `init_vllm_registered_model`. Its
    architecture is chosen from `config.model_type`: `glm4v` maps to
    `Glm4ForCausalLM` and `glm4v_moe` maps to `Glm4MoeForCausalLM`.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        # Listed as a single, already-fused module.
        "gate_up_proj": ["gate_up_proj"],
    }

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
        }
    )

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for `modality`.

        Raises:
            ValueError: If `modality` starts with neither "image" nor "video".
        """
        if modality.startswith("image"):
            return "<|begin_of_image|><|image|><|end_of_image|>"
        if modality.startswith("video"):
            return "<|begin_of_video|><|video|><|end_of_video|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # Both encoder options come from the same optional config; guard them
        # together instead of dereferencing it before the None check.
        if multimodal_config is not None:
            self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        else:
            self.use_data_parallel = False
            attn_backend_override = None

        self.visual = Glm4vVisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
            attn_backend_override=attn_backend_override,
        )

        if config.model_type == "glm4v":
            architectures = ["Glm4ForCausalLM"]
        elif config.model_type == "glm4v_moe":
            architectures = ["Glm4MoeForCausalLM"]
        else:
            architectures = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Glm4vImageInputs | None:
        """Pop image tensors from `kwargs`; pixel values win over embeddings."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Glm4vImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Glm4vImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Glm4vVideoInputs | None:
        """Pop video tensors from `kwargs`; pixel values win over embeddings."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Glm4vVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            return Glm4vVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _split_embeddings(
        self, embeds: torch.Tensor, grid_thw_list: list[list[int]]
    ) -> tuple[torch.Tensor, ...]:
        """Split concatenated encoder outputs into one tensor per item.

        Item `k` contributes ``prod(grid_thw_list[k]) // merge_size**2`` rows.
        """
        merge_size = self.visual.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()
        return embeds.split(sizes)

    def _process_image_input(
        self, image_input: Glm4vImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode (or cast precomputed) image features, one tensor per image."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

        return self._split_embeddings(image_embeds, grid_thw_list)

    def _process_video_input(
        self, video_input: Glm4vVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode (or cast precomputed) video features, one tensor per video."""
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual,
                    pixel_values_videos,
                    grid_thw_list,
                    rope_type="rope_3d",
                )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        return self._split_embeddings(video_embeds, grid_thw_list)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> MultiModalEmbeddings | None:
        """Return one embedding tensor per multimodal item, or None if none."""
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return None

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                image_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += tuple(video_embeddings)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for GLM-4V.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for GLM-4V
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Optional intermediate tensors for pipeline
                parallelism.
            inputs_embeds: Optional pre-computed input embeddings.
            **kwargs: Additional keyword arguments.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF checkpoint weights after renaming them via
        `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.model",
            connector="visual.merger.",
            tower_model="visual.",
        )


@MULTIMODAL_REGISTRY.register_processor(
    Glm4vMultiModalProcessor,
    info=Glm4vProcessingInfo,
    dummy_inputs=Glm4vDummyInputsBuilder,
)
class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }