chatglm.py 30.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
3
# Adapted from
4
5
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
6
from argparse import Namespace
7
8
from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
                    TypedDict, Union)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
9
10
11
12

import torch
from torch import nn
from torch.nn import LayerNorm
13
14
15
16
17
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
18

19
from vllm.attention import Attention, AttentionMetadata
20
from vllm.config import CacheConfig, VllmConfig
21
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
22
from vllm.logger import init_logger
23
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
26
27
                                               QKVParallelLinear,
                                               RowParallelLinear)
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
31
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
32
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
    ParallelLMHead, VocabParallelEmbedding)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
36
from vllm.model_executor.models.module_mapping import MultiModelKeys
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
from vllm.multimodal import MULTIMODAL_REGISTRY
39
40
41
42
43
44
45
46
47
48
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, BatchFeature,
                                        BoundPromptReplacement,
                                        MultiModalFieldConfig,
                                        PlaceholderFeaturesInfo,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
49
50
from vllm.transformers_utils.configs import ChatGLMConfig

51
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
52
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
53
                    make_empty_intermediate_tensors_factory, make_layers,
54
                    maybe_prefix, merge_multimodal_embeddings)
55
56
57

# Module-level logger created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

58
# Token id that marks where image features go in the prompt. The multimodal
# processor expands each occurrence into one id per image feature, and the
# model later overwrites the embeddings at those positions.
# NOTE(review): presumably the id of `<|endoftext|>` in the GLM-4V tokenizer
# (the dummy prompt places that token between the image markers) - confirm.
IMAGE_TOKEN_ID = 151329
59

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

def build_normalization_transform(image_size: int) -> transforms.Compose:
    """
    Assemble the preprocessing pipeline applied to each input image before
    visual feature extraction.

    The pipeline resizes to a square with bicubic interpolation, converts
    the image to a tensor, and normalizes it with fixed per-channel
    mean / standard-deviation constants.

    Args:
        image_size: side length of the square the image is resized to.

    Returns:
        Callable transform for normalizing and resizing one RGB image.
    """
    target_hw = (image_size, image_size)
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        transforms.Resize(target_hw,
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
84
85


86
87
def calculate_image_placeholder(vision_config):
    """Return how many placeholder tokens one image occupies.

    Args:
        vision_config: mapping with integer ``"image_size"`` and
            ``"patch_size"`` entries.

    Returns:
        The square of ``image_size // patch_size // 2``.
    """
    image_size = vision_config["image_size"]
    patch_size = vision_config["patch_size"]
    # Patches along one side, then halved (integer division both times).
    tokens_per_side = image_size // patch_size // 2
    return tokens_per_side * tokens_per_side
88
89
90
91
92
93
94


class GLMImagePixelInputs(TypedDict):
    """Typed container for the image pixel data pulled out of the model's
    keyword arguments before it is handed to the vision tower."""

    # Image tensor; see the shape note below.
    pixel_values: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


95
96
97
98
99
100
101
102
103
104
105
106
107
class GLM4VProcessor:
    """
    Hand-written stand-in for a HF processor.

    The model does not ship a HF processor of its own, so this class pairs
    the tokenizer with an image normalization transform and exposes the
    usual ``processor(text=..., images=...)`` call.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        # Only configs that carry a `vision_config` get an image transform;
        # otherwise it stays None and image inputs are rejected in __call__.
        has_vision = hasattr(self.config, "vision_config")
        self.image_transform = (build_normalization_transform(
            config.vision_config["image_size"]) if has_vision else None)

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and normalize ``images`` into one feature set.

        Args:
            text: a single prompt or a list of prompts; ``None`` is treated
                as an empty list.
            images: a single image or a list of images; ``None`` is treated
                as an empty list.
            return_tensors: forwarded to ``BatchFeature`` as ``tensor_type``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs plus, when
            images were given, a stacked ``pixel_values`` tensor.

        Raises:
            ValueError: if images are given but no image transform exists.
        """
        prompts = [] if text is None else text
        if not isinstance(prompts, list):
            prompts = [prompts]

        image_list = [] if images is None else images
        if not isinstance(image_list, list):
            image_list = [image_list]

        # Tokenize first so the outputs lead the merged feature dict.
        features = {**self.tokenizer(prompts)}

        if image_list:
            if self.image_transform is None:
                raise ValueError("This model does not support image inputs")

            transformed = [self.image_transform(img) for img in image_list]
            features["pixel_values"] = torch.stack(transformed)

        return BatchFeature(features, tensor_type=return_tensors)
149
150


151
class GLM4VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for GLM-4V: image limits, token counts, image
    size, and the processor used to turn prompts into model inputs."""

    def __init__(self, ctx):
        super().__init__(ctx)
        # Cache the values derived from the vision config once up front.
        self._pre_calculate()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-prompt item limit for each modality."""
        return dict(image=1)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the largest number of tokens one image can occupy."""
        max_image_tokens = self.image_token_num + 2
        return {"image": max_image_tokens}

    def _pre_calculate(self):
        """Derive and store ``image_token_num`` and ``image_size`` from the
        HF config's ``vision_config``."""
        vision_cfg = self.get_hf_config().vision_config
        self.image_token_num = calculate_image_placeholder(vision_cfg)
        self.image_size = vision_cfg["image_size"]

    def get_num_image_tokens(self) -> int:
        """Return the placeholder count plus two extra tokens.

        NOTE(review): the two extra slots match the begin-/end-of-image ids
        that the multimodal processor wraps around each placeholder run.
        """
        return self.image_token_num + 2

    def get_image_size(self) -> ImageSize:
        """Return the square image size expected by the vision config."""
        side = self.image_size
        return ImageSize(height=side, width=side)

    def get_hf_processor(self) -> GLM4VProcessor:
        """Build a ``GLM4VProcessor`` from the HF config and tokenizer."""
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        return GLM4VProcessor(hf_config, tokenizer)
186

187

188
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
    """Builds placeholder prompts and images for GLM-4V."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create dummy processor inputs.

        Args:
            seq_len: unused here.
            mm_counts: number of items requested per modality; only the
                ``"image"`` entry is read (defaults to 0).

        Returns:
            ``ProcessorInputs`` with a fixed image prompt and the requested
            number of dummy images at the configured image size.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text="<|begin_of_image|><|endoftext|><|end_of_image|>",
            mm_data={"image": dummy_images},
        )
209
210


211
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
    """Multimodal processor that expands the image token in GLM-4V prompts
    and reports where the image features land."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a batched field of modality "image"."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Describe how a single image token is expanded in the prompt.

        Each ``IMAGE_TOKEN_ID`` is replaced by ``image_token_num`` copies
        of itself, regardless of which item is being replaced.
        """

        def expand_image_token(item_idx: int):
            return [IMAGE_TOKEN_ID] * self.info.image_token_num

        image_replacement = PromptReplacement(
            modality="image",
            target=[IMAGE_TOKEN_ID],
            replacement=expand_image_token,
        )
        return [image_replacement]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Run the base replacement, then widen every placeholder by one
        token on each side.

        The widened placeholder starts one position earlier and its token
        list gains the begin-of-image id in front and the end-of-image id
        at the back. The token ids and text from the base class are
        returned unchanged.
        """
        token_ids, text, found = super()._apply_prompt_replacements(
            token_ids=token_ids,
            mm_prompt_repls=mm_prompt_repls,
            mm_item_counts=mm_item_counts,
        )

        hf_config = self.info.get_hf_config()
        boi, eoi = hf_config.boi_token_id, hf_config.eoi_token_id

        def widen(info):
            return PlaceholderFeaturesInfo(
                modality=info.modality,
                item_idx=info.item_idx,
                start_idx=info.start_idx - 1,
                tokens=[boi] + info.tokens + [eoi],
            )

        widened = {}
        for modality, infos in found.items():
            widened[modality] = [widen(info) for info in infos]

        return token_ids, text, widened
266

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
267
268
269

class GLMAttention(nn.Module):
    """Tensor-parallel self-attention for ChatGLM.

    Uses a fused QKV projection, rotary position embedding over half of
    each head's dimensions, the vLLM ``Attention`` backend, and a
    row-parallel output projection.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: model config supplying head counts, bias flags and
                rotary settings.
            cache_config: forwarded to the attention backend.
            quant_config: forwarded to the linear layers and the backend.
            prefix: dotted parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # With multi-query attention the KV head count comes from the
        # group number; otherwise it equals the query head count.
        self.multi_query_attention = config.multi_query_attention
        if config.multi_query_attention:
            self.total_num_kv_heads = config.multi_query_group_num
        else:
            self.total_num_kv_heads = config.num_attention_heads

        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than ranks: KV heads are replicated, so the
            # rank count must be a multiple of the KV head count.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: KV heads are partitioned
            # evenly across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        use_qkv_bias = config.add_bias_linear or config.add_qkv_bias
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=use_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )

        attn_width = self.total_num_heads * self.head_dim
        self.dense = RowParallelLinear(
            attn_width,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        rope_base = 10000 * rope_ratio
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        is_neox_style = not config.original_rope
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=rope_base,
            is_neox_style=is_neox_style,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding to Q and K, attend,
        and project the result back through ``dense``."""
        fused_qkv, _ = self.query_key_value(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)

        query, key = self.rotary_emb(position_ids, query, key)

        attended = self.attn(query, key, value, kv_cache, attn_metadata)

        projected, _ = self.dense(attended)
        return projected


class GLMMLP(nn.Module):
    """Gated feed-forward network.

    A merged column-parallel projection produces two ``ffn_hidden_size``
    slices, ``SiluAndMul`` combines them, and a row-parallel projection
    maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: model config supplying the layer sizes and bias flag.
            quant_config: forwarded to both linear layers.
            prefix: dotted parameter-name prefix of this module.
        """
        super().__init__()

        use_bias = config.add_bias_linear
        self.add_bias = use_bias

        # Up projection: two merged outputs of width ffn_hidden_size.
        merged_widths = [config.ffn_hidden_size] * 2
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            merged_widths,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )

        self.activation_func = SiluAndMul()

        # Down projection back to the hidden width.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, hidden_states):
        """Apply up projection, gated activation and down projection."""
        projected_up, _ = self.dense_h_to_4h(hidden_states)
        gated = self.activation_func(projected_up)
        projected_down, _ = self.dense_4h_to_h(gated)
        return projected_down


class GLMBlock(nn.Module):
    """One transformer layer: norm, self-attention, residual add, then
    norm, MLP, residual add.

    The output has the same size as the input hidden states.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: model config supplying sizes, norm type and residual
                options.
            cache_config: forwarded to the attention module.
            quant_config: forwarded to the attention and MLP modules.
            prefix: dotted parameter-name prefix of this layer.
        """
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
        self.fp32_residual_connection = config.fp32_residual_connection

        # Both norms in the layer share the same class and epsilon.
        norm_cls = RMSNorm if config.rmsnorm else LayerNorm
        norm_eps = config.layernorm_epsilon

        # Norm applied to the layer input.
        self.input_layernorm = norm_cls(config.hidden_size, eps=norm_eps)

        self.self_attention = GLMAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attention",
        )
        self.hidden_dropout = config.hidden_dropout

        # Norm applied to the attention output (after the residual add).
        self.post_attention_layernorm = norm_cls(config.hidden_size,
                                                 eps=norm_eps)

        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states`` (``[num_tokens, h]``).

        When ``apply_residual_connection_post_layernorm`` is set, each
        residual branch starts from the normalized tensor; otherwise it
        starts from the un-normalized one.
        """
        residual_post_ln = self.apply_residual_connection_post_layernorm

        # Attention sub-layer.
        normed_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attention(
            hidden_states=normed_input,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_residual = normed_input if residual_post_ln else hidden_states
        mlp_input = attn_residual + attn_out

        # MLP sub-layer.
        normed_mlp_input = self.post_attention_layernorm(mlp_input)
        mlp_residual = normed_mlp_input if residual_post_ln else mlp_input

        return self.mlp(normed_mlp_input) + mlp_residual


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final norm."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: model config supplying layer count, sizes and norm
                options.
            cache_config: forwarded to every layer.
            quant_config: forwarded to every layer.
            prefix: dotted parameter-name prefix of this module.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers

        def build_block(prefix: str) -> GLMBlock:
            # `make_layers` supplies the per-layer prefix.
            return GLMBlock(config,
                            cache_config,
                            quant_config,
                            prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            build_block,
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            # Final norm applied to the output of the last layer.
            norm_cls = RMSNorm if config.rmsnorm else LayerNorm
            self.final_layernorm = norm_cls(config.hidden_size,
                                            eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run layers ``start_layer`` to ``end_layer`` in order.

        ``kv_caches`` is indexed relative to ``start_layer``. The final
        norm runs only on the last pipeline rank, and only when
        ``post_layer_norm`` is enabled.
        """
        layer_indices = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_indices):
            block = self.layers[layer_idx]
            hidden_states = block(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[cache_idx],
                attn_metadata=attn_metadata,
            )

        on_last_rank = get_pp_group().is_last_rank
        if on_last_rank and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, transformer encoder stack and
    output head, plus an EVA2-CLIP vision tower when the HF config carries
    a ``vision_config``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Args:
            vllm_config: supplies the HF model config, the cache config
                and the quantization config.
            prefix: dotted parameter-name prefix of this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.embedding")

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        # The vision tower exists only when the config has a vision section;
        # `self.vision is None` marks a text-only model.
        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config,
                                        quant_config,
                                        prefix=f"{prefix}.vision")
        else:
            self.vision = None

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Extract ``pixel_values`` from ``kwargs`` and normalize it.

        When a vision tower exists and ``pixel_values`` is given:

        * a tensor with more than two dims is split along dim 0 and
          re-concatenated, which merges its two leading dims
          (NOTE(review): presumably a ``(batch, images_per_prompt, ...)``
          layout - confirm against the caller);
        * a list of tensors is concatenated along dim 0.

        Returns:
            ``GLMImagePixelInputs`` whose ``pixel_values`` is the
            normalized tensor, or ``None`` if no pixel values were passed.

        Raises:
            TypeError: if ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Fall through to the common return so callers always get
                # a GLMImagePixelInputs mapping. Returning the bare tensor
                # here broke `image_input["pixel_values"]` downstream.
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Encode the image inputs found in ``kwargs``.

        Returns:
            The vision tower output for the pixel values cast to
            ``config.torch_dtype``, or ``None`` when no pixel values were
            supplied.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input["pixel_values"] is None:
            return None
        pixel_values = image_input["pixel_values"].to(
            dtype=self.config.torch_dtype)
        vision_embeddings = self.vision(pixel_values)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in multimodal embeddings.

        When ``multimodal_embeddings`` is given, it is merged into the
        text embeddings at positions holding the begin-of-image, image
        placeholder or end-of-image token ids.
        """
        inputs_embeds = self.embedding(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=multimodal_embeddings,
                placeholder_token_id=[
                    self.config.boi_token_id,
                    IMAGE_TOKEN_ID,
                    self.config.eoi_token_id,
                ],
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the encoder stack.

        The encoder input is chosen in this order: the ``hidden_states``
        entry of ``intermediate_tensors`` if given, else ``inputs_embeds``
        if given, else embeddings computed here from ``input_ids`` and any
        image inputs in ``kwargs``.

        Returns:
            The final hidden states on the last pipeline rank; on other
            ranks an ``IntermediateTensors`` wrapping the hidden states.
        """

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        if intermediate_tensors is not None:
            inputs_embeds = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        # Checkpoint weights that are stored separately but live in one
        # merged parameter here; each goes into its own shard.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not loaded as a stacked shard. The `continue`s below skip
                # to the next weight without recording this one as loaded.
                if "rotary_pos_emb.inv_freq" in name:
                    continue
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
696

697
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Common wrapper around ``ChatGLMModel``.

    Builds the model as ``self.transformer``, exposes its ``output_layer`` as
    ``self.lm_head`` and provides the forward / logits / sampling /
    weight-loading entry points.
    """

    # Rewrites the ".word_embeddings" substring of checkpoint names to "".
    # NOTE(review): exact matching rules are defined by WeightsMapper.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""}, )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Store the relevant configs and build the wrapped model.

        Args:
            vllm_config: Source of the HF, quantization, LoRA and multimodal
                configs; also passed on to ``ChatGLMModel``.
            prefix: Parent prefix under which ``transformer`` is registered.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.multimodal_config = model_config.multimodal_config

        self.quant_config = vllm_config.quant_config
        # Falls back to 8192 when the config has no "max_sequence_length".
        self.max_position_embeddings = getattr(hf_config,
                                               "max_sequence_length", 8192)

        transformer_prefix = maybe_prefix(prefix, "transformer")
        self.transformer = ChatGLMModel(vllm_config=vllm_config,
                                        prefix=transformer_prefix)
        if self.config.tie_word_embeddings:
            # Tied embeddings: the output layer reuses the embedding weight.
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(hf_config.padded_vocab_size)
        self.sampler = get_sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Run ``self.transformer`` on the inputs and return its result.

        All arguments, including any extra keyword arguments, are forwarded
        unchanged.
        """
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                **kwargs)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the logits processor to ``hidden_states`` via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``weights`` through ``AutoWeightsLoader``.

        Names are remapped with ``hf_to_vllm_mapper``; the loader's return
        value is passed back to the caller.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper)
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775


class ChatGLM(ChatGLMBaseModel):
    """ChatGLM model variant without the vision-specific LoRA modules.

    Only declares the class-level module tables; all behaviour comes from
    ``ChatGLMBaseModel``.
    """

    packed_modules_mapping = dict(
        query_key_value=["query_key_value"],
        dense_h_to_4h=["dense_h_to_4h"],
    )

    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"
    ]

    # Intentionally empty for this model.
    embedding_modules = dict()
    embedding_padding_modules = list()


776
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
    """Multimodal ChatGLM variant.

    Extends the base module tables with vision entries and delegates
    multimodal / input embedding computation to ``self.transformer``.
    """

    packed_modules_mapping = dict(
        query_key_value=["query_key_value"],
        dense_h_to_4h=["dense_h_to_4h"],
        merged_proj=["gate_proj", "dense_h_to_4h"],
    )

    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",
        # vision
        "fc1", "fc2", "merged_proj", "linear_proj"
    ]

    # Intentionally empty for this model.
    embedding_modules = dict()
    embedding_padding_modules = list()

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        module_prefixes = dict(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer",
        )
        return MultiModelKeys.from_string_field(**module_prefixes)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Forward ``kwargs`` to the transformer's multimodal embedder."""
        embeddings = self.transformer.get_multimodal_embeddings(**kwargs)
        return embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Return the transformer's input embeddings for ``input_ids``.

        ``multimodal_embeddings`` is handed to the transformer unchanged.
        """
        embed = self.transformer.get_input_embeddings
        return embed(input_ids, multimodal_embeddings)

819

820
821
822
@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                        info=GLM4VProcessingInfo,
                                        dummy_inputs=GLM4VDummyInputsBuilder)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    """Dispatching entry point for the ChatGLM model family.

    Constructing this class never yields a ``ChatGLMForCausalLM`` instance:
    ``__new__`` returns a ``ChatGLMV`` when the HF config has a
    ``vision_config`` attribute and a ``ChatGLM`` otherwise.
    """

    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> ChatGLMBaseModel:
        """Select the concrete model class, publish its tables, build it.

        Args:
            vllm_config: Config whose ``model_config.hf_config`` decides
                which variant is built; passed on to the chosen class.
            prefix: Passed on to the chosen class.

        Returns:
            A ``ChatGLMV`` or ``ChatGLM`` instance.
        """
        config = vllm_config.model_config.hf_config

        # Initialize VL
        if hasattr(config, "vision_config"):  # noqa: SIM108
            instance_cls = ChatGLMV
        # Initialize LLM
        else:
            instance_cls = ChatGLM

        # quant_config references base class members,
        # so update values before init is called
        # The shared containers are mutated in place (never rebound) so
        # existing references to them keep seeing the merged content.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        cls.embedding_modules.update(instance_cls.embedding_modules)
        # Append only names that are not present yet: these lists live on
        # the class, so building the model more than once in a process must
        # not accumulate duplicate entries.
        for attr in ("supported_lora_modules", "embedding_padding_modules"):
            merged = getattr(cls, attr)
            for module_name in getattr(instance_cls, attr):
                if module_name not in merged:
                    merged.append(module_name)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)