# coding=utf-8
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only ChatGLM model compatible with THUDM weights."""
5
6
7
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
8
9

import torch
10
from PIL import Image
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
11
12
13
from torch import nn
from torch.nn import LayerNorm

14
from vllm.attention import Attention, AttentionMetadata
15
16
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
17
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
22
23
                                               QKVParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
26
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
27
from vllm.model_executor.layers.rotary_embedding import get_rope
28
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
33
from vllm.model_executor.sampling_metadata import SamplingMetadata
34
35
36
37
38
39
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
40
41
from vllm.transformers_utils.configs import ChatGLMConfig

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from .interfaces import SupportsLoRA, SupportsMultiModal

logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return the number of placeholder tokens reserved for one image.

    ``image_size`` is floor-divided by ``patch_size`` and then by 2; the
    resulting side length is squared.

    Args:
        vision_config: Mapping providing ``"image_size"`` and
            ``"patch_size"``.
    """
    side = vision_config["image_size"] // vision_config["patch_size"] // 2
    return side * side


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Convert image data into the ``pixel_values`` keyword input.

    The image is pushed through the tokenizer's chat template as a single
    user turn, and the ``'images'`` entry of the result is returned as
    ``pixel_values``.

    Args:
        ctx: Input context; its ``model_config`` supplies the tokenizer name
            and the ``trust_remote_code`` setting.
        data: Image payload forwarded to the chat template unchanged.

    Returns:
        ``MultiModalInputs`` holding a single ``'pixel_values'`` entry.

    Raises:
        RuntimeError: If no tokenizer could be obtained.
        Exception: Any error raised by ``apply_chat_template`` is logged and
            re-raised unchanged.
    """
    model_config = ctx.model_config
    # Honour the user's trust_remote_code choice rather than forcing it on;
    # this matches how input_processor_for_glmv loads its tokenizer.
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalInputs({'pixel_values': pixel_values})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Overwrite image spans of ``inputs_embeds`` with vision embeddings.

    Every begin-of-image marker is paired, in order, with an end-of-image
    marker. All rows from the begin marker through the end marker
    (inclusive) are replaced by the rows of ``vision_embeddings`` flattened
    to two dimensions.

    Args:
        input_ids: Token ids used to locate the markers.
        inputs_embeds: Embeddings to patch; modified in place.
        vision_embeddings: Source rows; the last dimension is kept and all
            leading dimensions are flattened.
        boi_token_id: Id of the begin-of-image marker.
        eoi_token_id: Id of the end-of-image marker.

    Returns:
        The same ``inputs_embeds`` tensor that was passed in.
    """
    span_starts = torch.where(input_ids == boi_token_id)[0].tolist()
    span_ends = torch.where(input_ids == eoi_token_id)[0].tolist()

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in zip(span_starts, span_ends):
        assert start < end
        # Both marker positions belong to the replaced span.
        image_mask[start:end + 1] = True

    flat_vision = vision_embeddings.view(-1, vision_embeddings.shape[-1])
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    # Typed dictionary carrying the image tensor that
    # `_parse_and_validate_image_input` hands to the model's forward pass.
    # The value may be None when no image was supplied (the forward pass
    # checks for that explicitly).
    pixel_values: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of tokens one image can occupy.

    Returns 1 when the HF config has no ``vision_config``; otherwise the
    placeholder count derived from a dict-valued ``vision_config``.

    Raises:
        NotImplementedError: If ``vision_config`` is present but not a dict.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig),
                            'vision_config', None)

    if vision_config is None:
        return 1

    if not isinstance(vision_config, dict):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return calculate_image_placeholder(vision_config)


def dummy_data_for_glmv(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    """Build dummy sequence data (and a dummy image) for this model.

    Without a ``vision_config`` the result is ``seq_len`` zero tokens and no
    multimodal data. With a dict-valued ``vision_config`` the sequence is
    one begin-of-image marker, the image placeholder zeros, one end-of-image
    marker and zero padding for the remainder, accompanied by a black
    square RGB image of side ``image_size``.

    Args:
        ctx: Input context used to fetch the HF config.
        seq_len: Requested sequence length.
        mm_counts: Not consulted by this implementation.

    Raises:
        NotImplementedError: If ``vision_config`` is present but not a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        text_only_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        return SequenceData(text_only_ids), None

    if not isinstance(vision_config, dict):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    image_size = vision_config["image_size"]
    num_placeholders = calculate_image_placeholder(vision_config)

    image_span = ([hf_config.boi_token_id] + [0] * num_placeholders +
                  [hf_config.eoi_token_id])
    # The two markers count towards seq_len, hence the extra "- 2".
    padding = [0] * (seq_len - num_placeholders - 2)

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_span)
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, padding)
    seq_data = SequenceData(token_ids)

    dummy_image = Image.new("RGB", (image_size, image_size), color=0)
    return seq_data, {"image": dummy_image}


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return, in ascending order, every index at which ``target`` occurs."""
    return [pos for pos, token in enumerate(input_ids) if token == target]


152
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand each image span of the prompt to the full placeholder length.

    The prompt and image are re-tokenized through the tokenizer's chat
    template. For every begin/end-of-image marker pair, the tokens strictly
    between the markers are replaced by ``image_placeholder_length`` copies
    of the token that follows the begin marker. Position ids are rebuilt
    alongside: regular tokens keep their index in the re-tokenized prompt,
    while all placeholders of one image share the index that follows the
    begin marker.

    Args:
        ctx: Input context giving access to the HF config and model config.
        inputs: Decoder-only inputs; updated in place and returned.

    Returns:
        ``inputs`` with ``"prompt_token_ids"`` and ``"position_ids"``
        replaced, or ``inputs`` unchanged when the model has no vision
        config or the request carries no image.

    Raises:
        NotImplementedError: If ``vision_config`` is present but not a dict.
        Exception: Any error raised by ``apply_chat_template`` is logged and
            re-raised unchanged.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # A text-only request to a vision-capable checkpoint has nothing to
    # expand; previously this path raised KeyError on 'multi_modal_data'.
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    tokenizer = cached_get_tokenizer(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": inputs['prompt']
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process content (%s)", inputs['prompt'])
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    new_position_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Copy everything up to and including the begin-of-image marker.
        new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                       1])
        new_position_ids.extend(
            list(range(final_processed_position, boi_position + 1)))
        # Repeat the token after the begin marker as the image placeholder;
        # every copy shares one position id.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        new_position_ids.extend([boi_position + 1] * image_placeholder_length)
        # Resume at the end-of-image marker so it is copied afterwards.
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])
    new_position_ids.extend(
        list(range(final_processed_position, len(input_ids))))

    assert len(new_input_ids) == len(new_position_ids)

    inputs["prompt_token_ids"] = new_input_ids
    inputs["position_ids"] = new_position_ids
    return inputs
220

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
221
222
223

class GLMAttention(nn.Module):
    """Self-attention layer with a fused QKV projection and rotary embedding.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks and replicated
    otherwise.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Model config; reads ``hidden_size``,
                ``num_attention_heads``, ``multi_query_attention``,
                ``multi_query_group_num``, ``add_bias_linear``,
                ``add_qkv_bias`` and optionally ``rope_ratio`` and
                ``seq_length``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # Rotary embedding covers only half of each head's dimensions.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary embedding, attend, project out."""
        qkv, _ = self.query_key_value(hidden_states)
        # Split the fused projection along the last dimension into the
        # per-rank query, key and value slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward block.

    The input is projected by one merged linear layer to two chunks of
    ``ffn_hidden_size`` each, combined by ``SiluAndMul`` and projected back
    to ``hidden_size``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the up- and down-projections.

        Args:
            config: Model config; reads ``hidden_size``, ``ffn_hidden_size``
                and ``add_bias_linear``.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Merged up-projection: two ffn_hidden_size outputs in one matmul.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        """Apply up-projection, gated activation and down-projection."""
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Applies a norm, self-attention with a residual connection, a second
    norm, and an MLP with a second residual connection. The output has the
    same size as the input.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the norms, attention and MLP of one layer.

        Args:
            config: Model config; reads
                ``apply_residual_connection_post_layernorm``,
                ``fp32_residual_connection``, ``rmsnorm``, ``hidden_size``,
                ``layernorm_epsilon`` and ``hidden_dropout``.
            cache_config: Forwarded to ``GLMAttention``.
            quant_config: Forwarded to ``GLMAttention`` and ``GLMMLP``.
        """
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, cache_config, quant_config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one transformer layer over ``hidden_states``.

        When ``apply_residual_connection_post_layernorm`` is set, each
        residual branch starts from the normalized tensor; otherwise it
        starts from the tensor that entered the norm.
        """
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final norm."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build ``config.num_layers`` blocks and, if enabled, a final norm.

        Args:
            config: Model config; reads ``post_layer_norm``, ``num_layers``,
                ``rmsnorm``, ``hidden_size`` and ``layernorm_epsilon``.
            cache_config: Forwarded to every ``GLMBlock``.
            quant_config: Forwarded to every ``GLMBlock``.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList([
            GLMBlock(config, cache_config, quant_config)
            for _ in range(self.num_layers)
        ])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run every layer in order, then the final norm if configured.

        ``kv_caches`` is indexed by layer position, so it needs one entry
        per layer.
        """
        for i, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """Token embedding, transformer encoder, LM head and optional vision
    tower.

    The vision tower is built only when the config has a ``vision_config``;
    otherwise ``self.vision`` is ``None``.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the sub-modules.

        Args:
            config: Model config; reads ``padded_vocab_size``,
                ``hidden_size``, ``num_layers``, ``multi_query_group_num``,
                ``kv_channels`` and optionally ``vision_config``.
            cache_config: Forwarded to ``GLMTransformer``.
            quant_config: Forwarded to the embeddings, the encoder and the
                vision tower.
        """
        super().__init__()

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config)

        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config, quant_config)
        else:
            self.vision = None

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Extract ``pixel_values`` from ``kwargs`` and normalize it.

        When a vision tower exists, a tensor with more than two dimensions
        is split along its first dimension and re-concatenated, and a list
        of tensors is concatenated. The result is always wrapped in
        ``GLMImagePixelInputs``.

        Raises:
            TypeError: If ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Fall through to the common return so callers always get
                # the dict; returning the bare tensor here broke the
                # ["pixel_values"] lookup in forward().
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor 
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Embed tokens, merge image embeddings if any, run the encoder.

        ``intermediate_tensors`` is accepted but not used by this method.
        Image data, when present, arrives as ``pixel_values`` in ``kwargs``.
        """
        inputs_embeds = self.embedding(input_ids)
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input["pixel_values"] is not None:
            # Match the embedding dtype before running the vision tower.
            pixel_values = image_input["pixel_values"].to(
                dtype=inputs_embeds.dtype)
            image_embeds = self.vision(pixel_values)

            boi_token_id = self.config.boi_token_id
            eoi_token_id = self.config.eoi_token_id

            inputs_embeds = merge_glm_vision_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                vision_embeddings=image_embeds,
                boi_token_id=boi_token_id,
                eoi_token_id=eoi_token_id)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states


561
562
563
564
565
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    """ChatGLM causal LM wrapper: model, logits processor, sampler and
    checkpoint loading."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the model, logits processor and sampler.

        Args:
            config: ChatGLM model config.
            multimodal_config: Stored on the instance.
            cache_config: Forwarded to ``ChatGLMModel``.
            quant_config: Stored and forwarded to ``ChatGLMModel``.
            lora_config: Stored on the instance.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            # Share one weight between the input embedding and the LM head.
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Return the transformer's hidden states.

        ``intermediate_tensors`` is accepted but not forwarded; extra
        keyword arguments are passed through to the inner model.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over ``hidden_states`` with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        The vision projector's ``gate_proj`` and ``dense_h_to_4h`` weights
        are collected and concatenated along dim 0 into ``merged_proj``.
        ``rotary_pos_emb.inv_freq`` entries and bias entries that have no
        matching parameter are skipped, and ``.word_embeddings`` is removed
        from names.

        Raises:
            ValueError: If the merged parameter exists but one of its
                source weights was not found in ``weights``.
        """
        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
            "transformer.vision.linear_proj.merged_proj.weight": {
                "transformer.vision.linear_proj.gate_proj.weight": None,
                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
            }
        }

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            is_weight_to_be_merge = False
            for merged_weight_dict in merged_weights_dict.values():
                if name in merged_weight_dict:
                    assert merged_weight_dict[name] is None
                    merged_weight_dict[name] = loaded_weight
                    is_weight_to_be_merge = True
            if is_weight_to_be_merge:
                continue
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        for combined_name, merged_weight_dict in merged_weights_dict.items():
            if combined_name not in params_dict:
                continue
            # Fail with a clear message rather than letting torch.cat choke
            # on a None entry.
            missing = [
                part_name
                for part_name, part_weight in merged_weight_dict.items()
                if part_weight is None
            ]
            if missing:
                raise ValueError(
                    f"Missing checkpoint weights {missing} required to "
                    f"build {combined_name}")
            param = params_dict[combined_name]
            combined_weight = torch.cat(list(merged_weight_dict.values()),
                                        dim=0)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, combined_weight)