glm4v.py 20.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

# Adapted from
5
# https://github.com/zai-org/CogAgent
6
"""Inference-only CogAgent model compatible with THUDM weights."""
7

8
from argparse import Namespace
9
from collections.abc import Iterator, Mapping, Sequence
10
from typing import Annotated, Literal
11

12
import numpy as np
13
14
15
import torch
from torch import nn
from torch.nn import LayerNorm
16
from transformers import BatchFeature
17
18

from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
21
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
22
from vllm.model_executor.layers.attention import MMEncoderAttention
23
from vllm.model_executor.layers.conv import Conv2dLayer
24
25
26
27
28
29
30
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
31
32
33
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
from vllm.multimodal.inputs import (
    MultiModalDataDict,
36
    MultiModalFeatureSpec,
37
38
39
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
40
from vllm.multimodal.parse import MultiModalDataItems
41
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
46
47
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
48
from vllm.sequence import IntermediateTensors
49
50
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.processors.glm4v import GLM4VProcessor
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
54
55
56
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
57
    SupportsMRoPE,
58
59
60
    SupportsMultiModal,
    SupportsPP,
)
61
62


63
64
65
66
67
68
69
70
class GLMVImagePixelInputs(TensorSchema):
    """
    Validated pixel input for GLM-4V.

    Dimensions:
        - b: Batch size (number of images)
        - c: Number of channels (fixed at 3)
        - h: Image height in pixels
        - w: Image width in pixels
    """

    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
74
75
76
77
78


class EVA2CLIPPatchEmbedding(nn.Module):
    """Turn images into patch tokens, prepend a CLS token, add position embeddings."""

    def __init__(self, config):
        super().__init__()
        patch = config.patch_size
        width = config.hidden_size
        # kernel_size == stride, so patches do not overlap.
        self.proj = Conv2dLayer(
            in_channels=config.in_channels,
            out_channels=width,
            kernel_size=patch,
            stride=patch,
        )
        self.cls_embedding = nn.Parameter(torch.zeros(1, width))
        self.position_embedding = nn.Embedding(config.num_positions, width)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Image batch of shape (B, C, H, W). It is moved to the device and
            dtype of the patch projection weight before use.

        Returns:
        torch.Tensor
            Token sequence of shape (B, 1 + num_patches, D). The CLS token is
            at index 0.
        """
        weight = self.proj.weight
        pixels = images.to(device=weight.device, dtype=weight.dtype)
        # (B, D, H', W') -> (B, H'*W', D)
        patches = self.proj(pixels).flatten(2).transpose(1, 2)
        batch = patches.shape[0]
        cls_tokens = self.cls_embedding.expand(batch, -1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)
        # One learned position per token, CLS included.
        tokens += self.position_embedding.weight.unsqueeze(0)
        return tokens


class EVA2CLIPAttention(nn.Module):
    """Self-attention for the EVA2-CLIP vision tower, with heads split across TP ranks."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        total_heads = config.num_heads
        width = config.hidden_size
        self.hidden_size = width
        self.tp_size = get_tensor_model_parallel_world_size()
        # Heads owned by this rank (integer division of the global head count).
        self.num_heads_per_rank = total_heads // self.tp_size
        self.head_dim = width // total_heads
        self.scale = self.head_dim**-0.5

        # Fused Q/K/V projection.
        self.query_key_value = QKVParallelLinear(
            width,
            self.head_dim,
            total_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        # Output projection back to hidden_size.
        self.dense = RowParallelLinear(
            width,
            width,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        self.attn = MMEncoderAttention(
            self.num_heads_per_rank,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )
        self.output_dropout = torch.nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The projection layers return a pair; only the first element is used.
        fused, _ = self.query_key_value(x)
        # Split the last dim into equal Q, K, V thirds.
        query, key, value = fused.chunk(3, dim=-1)
        attended = self.attn(query, key, value)
        projected, _ = self.dense(attended)
        return self.output_dropout(projected)


class EVA2CLIPMLP(nn.Module):
    """Feed-forward block of the vision tower: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            width,
            inner,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            inner,
            width,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each parallel linear returns a pair; keep only the output tensor.
        hidden, _ = self.fc1(x)
        activated = self.activation_fn(hidden)
        out, _ = self.fc2(activated)
        return out


class EVA2CLIPTransformerLayer(nn.Module):
    """One vision-tower layer: attention and MLP, each with a residual add.

    The LayerNorms are applied to each sub-layer's OUTPUT, before it is added
    back to the residual stream, not to its input.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.input_layernorm = LayerNorm(width, eps=eps)
        self.attention = EVA2CLIPAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.attention"
        )
        self.mlp = EVA2CLIPMLP(
            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )
        self.post_attention_layernorm = LayerNorm(width, eps=eps)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = residual + self.input_layernorm(self.attention(residual))
        residual = hidden_states
        return residual + self.post_attention_layernorm(self.mlp(residual))


class EVA2CLIPTransformer(nn.Module):
    """A stack of ``config.num_hidden_layers`` EVA2-CLIP layers applied in order."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        blocks = []
        for idx in range(config.num_hidden_layers):
            blocks.append(
                EVA2CLIPTransformerLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                )
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        for block in self.layers:
            hidden_states = block(hidden_states)
        return hidden_states


class EVA2CLIPGLU(nn.Module):
    """Gated MLP adapter that maps vision features into the LM hidden size.

    The original implementation used two separate projections,
    ``gate_proj`` and ``dense_h_to_4h``. Both map
    ``hidden_size -> ffn_hidden_size`` with no bias, and their outputs were
    concatenated as::

        gate_out, _ = gate_proj(x)
        up_out, _ = dense_h_to_4h(x)
        x = torch.cat([gate_out, up_out], dim=-1)

    Here they are fused into one ``MergedColumnParallelLinear``
    (``merged_proj``) with output sizes ``[ffn_hidden_size] * 2``, in that same
    order, so a single call produces the concatenated result.
    """

    def __init__(
        self,
        config,
        in_features,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        ffn = config.ffn_hidden_size
        self.linear_proj = ReplicatedLinear(
            in_features,
            width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_proj",
        )
        self.norm1 = nn.LayerNorm(width)
        self.act1 = nn.GELU()
        # Combines the two halves of merged_proj's output
        # (presumably silu(gate) * up).
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            width,
            [ffn] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj",
        )

        self.dense_4h_to_h = RowParallelLinear(
            ffn,
            width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, x):
        projected, _ = self.linear_proj(x)
        normed = self.act1(self.norm1(projected))
        gate_up, _ = self.merged_proj(normed)
        gated = self.act2(gate_up)
        out, _ = self.dense_4h_to_h(gated)
        return out


class EVA2CLIPModel(nn.Module):
    """EVA2-CLIP image encoder plus the adapter into the language model.

    Pipeline: patch embedding -> transformer -> drop CLS -> 2x2 stride-2 conv
    over the patch grid -> GLU adapter -> wrap with learned boi/eoi
    embeddings -> divide by ``scaling_factor``.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        vision_config = Namespace(**config.vision_config)
        lm_width = config.hidden_size
        self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config)
        self.transformer = EVA2CLIPTransformer(
            vision_config, quant_config=quant_config, prefix=f"{prefix}.transformer"
        )
        self.linear_proj = EVA2CLIPGLU(
            config,
            in_features=lm_width,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_proj",
        )
        # Merges each 2x2 neighbourhood of patches into one token.
        self.conv = Conv2dLayer(
            in_channels=vision_config.hidden_size,
            out_channels=lm_width,
            kernel_size=2,
            stride=2,
        )
        self.boi = nn.Parameter(torch.zeros(1, 1, lm_width))
        self.eoi = nn.Parameter(torch.zeros(1, 1, lm_width))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Image batch of shape (B, C, H, W).

        Returns:
        torch.Tensor
            Embeddings of shape (B, 1 + L + 1, D): the boi embedding, L merged
            image tokens, then the eoi embedding.
        """
        encoded = self.transformer(self.patch_embedding(images))
        # Index 0 is the CLS token added by the patch embedding; drop it.
        patches = encoded[:, 1:]

        batch, seq_len, width = patches.shape
        # Assumes the patch tokens form a square grid.
        side = int(seq_len**0.5)
        grid = patches.view(batch, side, side, width).permute(0, 3, 1, 2)
        merged = self.conv(grid).flatten(2).transpose(1, 2)

        projected = self.linear_proj(merged)
        begin = self.boi.expand(batch, -1, -1)
        end = self.eoi.expand(batch, -1, -1)
        wrapped = torch.cat((begin, projected, end), dim=1)
        return wrapped / self.scaling_factor


class GLM4VModel(ChatGLMModel):
    """ChatGLM backbone extended with an EVA2-CLIP vision tower at ``.vision``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.vision = EVA2CLIPModel(
            self.config,
            vllm_config.quant_config,
            prefix=f"{prefix}.vision",
        )
384
385
386
387
388
389


class GLM4VProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count information for GLM-4V inputs."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(ChatGLMConfig)

    def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
        """Build the GLM-4V processor.

        ``image_size`` is always taken from the model's vision config and
        overrides any value passed in ``kwargs``.
        """
        processor_kwargs = dict(kwargs)
        processor_kwargs["image_size"] = self.get_hf_config().vision_config["image_size"]
        return self.ctx.init_processor(
            GLM4VProcessor,
            tokenizer=self.get_tokenizer(),
            **processor_kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one image per prompt.
        return {"image": 1}

    def get_num_image_tokens(self) -> int:
        """Number of image tokens, excluding boi/eoi, for one image."""
        vision = self.get_hf_config().vision_config
        # Patches per side, halved by the 2x2 stride-2 merge conv.
        side = vision["image_size"] // vision["patch_size"] // 2
        return side * side

    def get_num_image_feature_tokens(self) -> int:
        # Add the boi and eoi embeddings produced by EVA2CLIPModel.
        return 2 + self.get_num_image_tokens()


class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
    """Builds placeholder prompts and images for profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        placeholder = "<|begin_of_image|><|endoftext|><|end_of_image|>"
        return placeholder * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        # Dummy images are square, at the configured image_size.
        side = self.info.get_hf_config().vision_config["image_size"]
        return {
            "image": self._get_dummy_images(
                width=side,
                height=side,
                num_images=mm_counts.get("image", 0),
                overrides=mm_options.get("image"),
            )
        }


class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
    """Expands each image placeholder in the prompt into the full token run."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # Report that the HF processor does not expand placeholders itself.
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Only pixel_values is forwarded, batched per image.
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace ``[boi, pad, eoi]`` with ``[boi] + [pad] * N + [eoi]``."""
        cfg = self.info.get_hf_config()
        boi_id = cfg.boi_token_id
        pad_id = cfg.pad_token_id
        eoi_id = cfg.eoi_token_id

        def get_replacement(item_idx: int):
            # The same size for every image: images have a fixed resolution.
            body = [pad_id] * self.info.get_num_image_tokens()
            return [boi_id, *body, eoi_id]

        return [
            PromptReplacement(
                modality="image",
                target=[boi_id, pad_id, eoi_id],
                replacement=get_replacement,
            ),
        ]


494
495
496
497
498
@MULTIMODAL_REGISTRY.register_processor(
    GLM4VMultiModalProcessor,
    info=GLM4VProcessingInfo,
    dummy_inputs=GLM4VDummyInputsBuilder,
)
class GLM4VForCausalLM(
    ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """GLM-4V: a ChatGLM language model with an EVA2-CLIP image encoder."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        # Vision adapter: gate_proj and dense_h_to_4h are fused, in this order.
        "merged_proj": ["gate_proj", "dense_h_to_4h"],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for one item of ``modality``.

        Raises:
            ValueError: if ``modality`` is not an image modality.
        """
        if modality.startswith("image"):
            return "<|begin_of_image|><|endoftext|><|end_of_image|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[GLM4VModel] = GLM4VModel,
    ) -> None:
        # Tag the language stack and the vision tower for composite-model handling.
        with self._mark_composite_model(
            vllm_config,
            language_targets=GLMTransformer,
            tower_targets={"image": EVA2CLIPModel},
        ):
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                transformer_type=transformer_type,
            )

        self.transformer: GLM4VModel

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> GLMVImagePixelInputs | None:
        """Pop ``pixel_values`` from ``kwargs`` and wrap them in the input schema.

        Height and width are bound to the configured ``image_size``. Returns
        ``None`` when no pixel values were passed.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config["image_size"]
        return GLMVImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the validated pixel values."""
        pixel_values = image_input["data"].to(dtype=self.config.dtype)
        return self.transformer.vision(pixel_values)

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int]]:
        """Yield ``(offset, t, h, w)`` for each image feature, in prompt order.

        ``h`` and ``w`` are measured in image tokens after the 2x2 merge conv.

        GLM-4V inputs are validated at a fixed ``image_size``, and the
        processor forwards only ``pixel_values``: no per-item
        ``image_grid_thw`` exists. ``vision_config`` is also a plain dict, not
        an attribute object. The grid is therefore derived from the config,
        using the same formula as
        ``GLM4VProcessingInfo.get_num_image_tokens``.

        Raises:
            ValueError: when a non-image feature is encountered.
        """
        vision_config = self.config.vision_config
        grid = vision_config["image_size"] // vision_config["patch_size"] // 2
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            if mm_feature.modality != "image":
                # glm4v only supports image modality
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")
            yield mm_feature.mm_position.offset, 1, grid, grid

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE (t, h, w) positions for every prompt token.

        Each image placeholder starts at ``mm_position.offset``. It consists of
        ``boi``, then ``t * h * w`` image tokens, then ``eoi``. The ``boi`` and
        ``eoi`` tokens get ordinary text positions, identical in all three rows.
        Image tokens get 3-D grid indices offset from the position right after
        ``boi``.

        Returns:
            ``(positions, delta)``. ``positions`` has shape
            ``(3, len(input_tokens))``. ``delta`` is
            ``positions.max() + 1 - len(input_tokens)``.
        """
        llm_pos_ids_list: list[np.ndarray] = []
        st = 0
        for (
            offset,
            llm_grid_t,
            llm_grid_h,
            llm_grid_w,
        ) in self.iter_mm_grid_thw(mm_features):
            # Text since the previous image (including its eoi) plus this image's
            # boi token. The value is always >= 1, so .max() below never sees an
            # empty array.
            text_len = offset + 1 - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
                3, -1
            )
            llm_pos_ids_list.append(grid_indices + text_len + st_idx)
            # Resume at this image's eoi token, which is positioned as text.
            st = offset + 1 + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            text_len = len(input_tokens) - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        if not llm_pos_ids_list:
            # Empty prompt: np.concatenate would raise on an empty list.
            return torch.zeros((3, 0), dtype=torch.long), 0

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta

    embed_input_ids = SupportsMultiModal.embed_input_ids

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode images found in ``kwargs``. Returns ``[]`` when there are none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # When intermediate tensors are passed in, inputs_embeds is ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states