chatglm.py 25.1 KB
Newer Older
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
1
2
# coding=utf-8
# Adapted from
3
# https://github.com/THUDM/GLM-4
Woosuk Kwon's avatar
Woosuk Kwon committed
4
"""Inference-only ChatGLM model compatible with THUDM weights."""
5
6
7
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
8
9

import torch
10
from PIL import Image
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
11
12
13
from torch import nn
from torch.nn import LayerNorm

14
from vllm.attention import Attention, AttentionMetadata
15
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
16
17
18
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                         token_inputs)
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
21
from vllm.model_executor.layers.layernorm import RMSNorm
22
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
23
24
                                               QKVParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
33
from vllm.model_executor.sampling_metadata import SamplingMetadata
34
35
36
37
38
39
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
40
41
from vllm.transformers_utils.configs import ChatGLMConfig

42
43
44
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return how many placeholder tokens a single image expands to.

    The image side length is floor-divided by the patch size, the result
    is floor-divided by two, and that per-side count is squared.
    """
    patches_per_side = (vision_config["image_size"] //
                        vision_config["patch_size"])
    tokens_per_side = patches_per_side // 2
    return tokens_per_side**2


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Map raw image data to the ``pixel_values`` multimodal input.

    The image is run through the tokenizer's chat template and the
    ``images`` entry of the result is returned as ``pixel_values``.

    Raises:
        RuntimeError: if no tokenizer could be obtained.
        Exception: whatever ``apply_chat_template`` raises is logged and
            re-raised unchanged.
    """
    model_config = ctx.model_config
    # Respect the user's trust_remote_code choice rather than forcing it on;
    # this matches how input_processor_for_glmv loads the tokenizer.
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalInputs({'pixel_values': pixel_values})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Write vision embeddings over the image spans of ``inputs_embeds``.

    Each span runs from a begin-of-image token to the matching
    end-of-image token, both inclusive. ``inputs_embeds`` is modified in
    place and also returned.
    """
    span_starts = torch.nonzero(input_ids == boi_token_id,
                                as_tuple=True)[0].tolist()
    span_stops = torch.nonzero(input_ids == eoi_token_id,
                               as_tuple=True)[0].tolist()

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, stop in zip(span_starts, span_stops):
        assert start < stop
        # Slice end is exclusive, so +1 keeps the end-of-image token.
        image_mask[start:stop + 1] = True

    # Collapse any leading dims so rows line up with the masked positions.
    flat_vision = vision_embeddings.view(-1, vision_embeddings.shape[-1])
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    """Typed container for the image tensor passed to the vision model."""

    # Shape: `(batch_size, num_channels, height, width)`
    pixel_values: torch.Tensor


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of tokens one image can occupy.

    Returns 1 when the HF config has no ``vision_config``; otherwise the
    placeholder count derived from the ``vision_config`` dict.

    Raises:
        NotImplementedError: if ``vision_config`` is set but not a dict.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig), 'vision_config',
                            None)
    if vision_config is None:
        return 1
    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)
    return calculate_image_placeholder(vision_config)


def dummy_data_for_glmv(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    """Build dummy sequence (and image) data for this model.

    Without a ``vision_config`` the result is ``seq_len`` zero tokens and
    no multimodal data. With a dict ``vision_config`` the sequence starts
    with one begin-of-image token, the image placeholders and one
    end-of-image token, is zero-padded towards ``seq_len``, and comes with
    a black square RGB image of side ``image_size``.

    Raises:
        NotImplementedError: if ``vision_config`` is set but not a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        text_only_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        return SequenceData(text_only_ids), None

    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    side = vision_config["image_size"]
    num_placeholders = calculate_image_placeholder(vision_config)

    image_span = [
        hf_config.boi_token_id,
        *([0] * num_placeholders),
        hf_config.eoi_token_id,
    ]
    # The two extra tokens are the begin/end-of-image markers.
    padding = [0] * (seq_len - num_placeholders - 2)

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_span)
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, padding)

    dummy_image = Image.new("RGB", (side, side), color=0)
    return SequenceData(token_ids), {"image": dummy_image}


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return every index at which ``target`` occurs in ``input_ids``."""
    positions: List[int] = []
    for position, token_id in enumerate(input_ids):
        if token_id == target:
            positions.append(position)
    return positions


154
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand image placeholders in the prompt for GLM-4V style inputs.

    Inputs without image data, or for a model without ``vision_config``,
    are returned untouched. Otherwise the prompt and image are re-tokenized
    through the tokenizer's chat template, and for each begin/end-of-image
    pair the token following the begin-of-image token is repeated
    ``image_placeholder_length`` times between the two markers.

    Raises:
        NotImplementedError: if ``vision_config`` is set but not a dict.
        Exception: whatever ``apply_chat_template`` raises is logged and
            re-raised unchanged.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    tokenizer = cached_get_tokenizer(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": inputs['prompt'],
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).data
    except Exception:
        logger.error("Failed to process content (%s)", inputs['prompt'])
        raise
    # The token ids always come from the chat template output; the incoming
    # prompt_token_ids are not used.
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Copy everything up to and including the begin-of-image token.
        new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                       1])
        # Repeat the token right after begin-of-image as the placeholder.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        # Resume at the end-of-image token so it is kept by the next copy.
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])

    prompt = inputs.get("prompt")
    if prompt is None:
        prompt = tokenizer.decode(new_input_ids)

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )
223

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
224
225
226

class GLMAttention(nn.Module):
    """Self-attention layer of a GLM block.

    Combines a fused QKV projection, rotary position embedding applied to
    the query and key, the vLLM ``Attention`` op and an output projection.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # Only half of each head's dimensions are rotated.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend and project."""
        qkv, _ = self.query_key_value(hidden_states)
        # Split the fused projection into this rank's Q, K and V slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. The two equally sized output halves are consumed
        # together by SiluAndMul.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        """Apply the up projection, the activation and the down projection."""
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, cache_config, quant_config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP, each followed by a residual connection.

        ``apply_residual_connection_post_layernorm`` selects whether each
        residual is taken from the normalized or the un-normalized tensor.
        """
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Transformer class.

    Holds the GLM blocks assigned to this rank (layers ``start_layer`` up
    to ``end_layer``) and, when ``config.post_layer_norm`` is set, a final
    layer norm.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            lambda prefix: GLMBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run this rank's layers and, on the last rank, the final norm."""
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                # kv_caches is indexed from 0 for this rank's first layer.
                kv_cache=kv_caches[i - self.start_layer],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if get_pp_group().is_last_rank and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, transformer encoder, output head.

    When the config carries a ``vision_config`` an ``EVA2CLIPModel`` vision
    tower is also built and its output is merged into the token embeddings.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config)

        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config, quant_config)
        else:
            self.vision = None

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Extract ``pixel_values`` from ``kwargs`` as one batched tensor.

        A tensor with more than two dims is concatenated along its first
        axis; a list of tensors is concatenated. In every non-error case
        the result is wrapped in ``GLMImagePixelInputs``, with
        ``pixel_values`` set to ``None`` when no image was passed.

        Raises:
            TypeError: if ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Fall through to the common return so callers always get a
                # GLMImagePixelInputs mapping rather than a bare tensor.
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor 
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Embed the inputs (or take pipeline inputs) and run the encoder.

        Returns the hidden states on the last pipeline rank and an
        ``IntermediateTensors`` wrapper around them on every other rank.
        """
        if intermediate_tensors is None:
            inputs_embeds = self.embedding(input_ids)
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input["pixel_values"] is not None:
                pixel_values = image_input["pixel_values"].to(
                    dtype=inputs_embeds.dtype)
                image_embeds = self.vision(pixel_values)

                boi_token_id = self.config.boi_token_id
                eoi_token_id = self.config.eoi_token_id

                inputs_embeds = merge_glm_vision_embeddings(
                    input_ids=input_ids,
                    inputs_embeds=inputs_embeds,
                    vision_embeddings=image_embeds,
                    boi_token_id=boi_token_id,
                    eoi_token_id=eoi_token_id)
        else:
            # Hidden states handed over by the previous pipeline stage.
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return hidden_states


578
579
580
581
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    """ChatGLM causal language model with optional image input.

    Wraps ``ChatGLMModel`` and adds logits computation, sampling and
    checkpoint weight loading.
    """

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            # Share one weight between the input embedding and output head.
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Delegate to the wrapped ``ChatGLMModel`` and return its output."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        The vision projection's ``gate_proj`` and ``dense_h_to_4h`` weights
        are held back and loaded afterwards as one concatenated tensor into
        ``merged_proj``. Rotary inverse-frequency buffers, biases with no
        matching parameter and parameters owned by another pipeline rank
        are skipped.
        """
        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
            "transformer.vision.linear_proj.merged_proj.weight": {
                "transformer.vision.linear_proj.gate_proj.weight": None,
                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
            }
        }

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            is_weight_to_be_merge = False
            for merged_weight_dict in merged_weights_dict.values():
                if name in merged_weight_dict:
                    assert merged_weight_dict[name] is None
                    merged_weight_dict[name] = loaded_weight
                    is_weight_to_be_merge = True
            if is_weight_to_be_merge:
                continue
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        for combined_name, merged_weight_dict in merged_weights_dict.items():
            # Models without the merged parameter have nothing to load here.
            if combined_name in params_dict:
                param = params_dict[combined_name]
                # Concatenate in the order the sub-weights are declared.
                combined_weight = torch.cat(list(merged_weight_dict.values()),
                                            dim=0)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, combined_weight)