chatglm.py 26.8 KB
Newer Older
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
1
2
# coding=utf-8
# Adapted from
3
# https://github.com/THUDM/GLM-4
Woosuk Kwon's avatar
Woosuk Kwon committed
4
"""Inference-only ChatGLM model compatible with THUDM weights."""
5
6
7
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
8
9

import torch
10
from PIL import Image
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
11
12
from torch import nn
from torch.nn import LayerNorm
zhuwenwen's avatar
zhuwenwen committed
13
import os
zhuwenwen's avatar
zhuwenwen committed
14
import re
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
15

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
20
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
21
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
22
from vllm.model_executor.layers.layernorm import RMSNorm
23
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
24
25
                                               QKVParallelLinear,
                                               RowParallelLinear)
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
27
28
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
29
from vllm.model_executor.layers.rotary_embedding import get_rope
30
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
31
from vllm.model_executor.layers.vocab_parallel_embedding import (
32
    ParallelLMHead, VocabParallelEmbedding)
33
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
35
from vllm.model_executor.sampling_metadata import SamplingMetadata
36
37
38
39
40
41
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
42
from vllm.transformers_utils.configs import ChatGLMConfig
43

zhuwenwen's avatar
zhuwenwen committed
44
from vllm import _custom_ops as ops
45
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
46

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from .interfaces import SupportsLoRA, SupportsMultiModal

logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return how many placeholder tokens a single image expands to.

    The image is cut into ``image_size // patch_size`` patches per side, the
    per-side count is then halved (integer division), and the result is the
    square of that grid width.
    """
    image_size = vision_config["image_size"]
    patch_size = vision_config["patch_size"]
    grid_width = image_size // patch_size // 2
    return grid_width * grid_width


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Turn raw image data into ``pixel_values`` for the GLM-4V vision tower.

    The image is passed through the tokenizer's chat template, and the
    template's ``images`` output is used as the preprocessed pixel tensor.

    Args:
        ctx: Input context giving access to the model config.
        data: The raw image data of one request.

    Returns:
        ``MultiModalInputs`` holding a single ``pixel_values`` entry.

    Raises:
        RuntimeError: if no tokenizer could be obtained.
        Exception: anything raised by ``apply_chat_template`` is logged and
            re-raised unchanged.
    """
    model_config = ctx.model_config
    # Honour the user's trust_remote_code setting (as input_processor_for_glmv
    # does) instead of unconditionally executing remote tokenizer code.
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalInputs({'pixel_values': pixel_values})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Overwrite image-span rows of ``inputs_embeds`` with vision embeddings.

    Every position from a begin-of-image token up to and including its
    end-of-image token is selected; the k-th begin token is paired with the
    k-th end token. The selected rows are replaced, in order, by
    ``vision_embeddings`` flattened to ``(-1, vision_embeddings.shape[-1])``.

    ``inputs_embeds`` is modified in place and also returned.
    """
    starts = torch.nonzero(input_ids == boi_token_id, as_tuple=True)[0]
    ends = torch.nonzero(input_ids == eoi_token_id, as_tuple=True)[0]

    is_image = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in zip(starts, ends):
        assert start < end
        # Inclusive of both markers.
        is_image[start:end + 1] = True

    hidden_size = vision_embeddings.shape[-1]
    inputs_embeds[is_image] = vision_embeddings.view(-1, hidden_size)
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    """Image inputs handed to the GLM-4V vision tower."""

    # Shape: `(batch_size, num_channels, height, width)`
    pixel_values: torch.Tensor


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of tokens one image can occupy.

    Text-only checkpoints (no ``vision_config``) report 1; checkpoints with
    a dict ``vision_config`` report the per-image placeholder count.

    Raises:
        NotImplementedError: if ``vision_config`` is present but not a dict.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig), 'vision_config',
                            None)
    if vision_config is None:
        return 1
    if not isinstance(vision_config, dict):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")
    return calculate_image_placeholder(vision_config)


def dummy_data_for_glmv(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    """Build the dummy prompt (and image) used for profiling.

    Text-only checkpoints get ``seq_len`` zero tokens and no multimodal
    data. Checkpoints with a dict ``vision_config`` get the sequence
    ``[boi_token_id] + [0] * placeholders + [eoi_token_id]`` padded with
    zeros, plus an all-black square RGB image of side ``image_size``.
    ``mm_counts`` is not used.

    Raises:
        NotImplementedError: if ``vision_config`` is present but not a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        text_only_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        return SequenceData(text_only_ids), None

    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    image_size = vision_config["image_size"]
    num_placeholders = calculate_image_placeholder(vision_config)

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id])
    token_ids.extend([0] * num_placeholders)
    token_ids.append(hf_config.eoi_token_id)
    # NOTE(review): when seq_len < num_placeholders + 2 this padding is empty
    # and the dummy sequence ends up longer than seq_len.
    token_ids.extend([0] * (seq_len - num_placeholders - 2))

    blank_image = Image.new("RGB", (image_size, image_size), color=0)
    return SequenceData(token_ids), {"image": blank_image}


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return, in ascending order, every index whose value equals ``target``."""
    positions: List[int] = []
    for index, token_id in enumerate(input_ids):
        if token_id == target:
            positions.append(index)
    return positions


def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand the image markers of a GLM-4V prompt into per-patch placeholders.

    Text-only checkpoints (no ``vision_config``) are returned unchanged.
    Otherwise the prompt and image are re-tokenized through the tokenizer's
    chat template; any incoming ``prompt_token_ids`` are discarded. For each
    begin-of-image / end-of-image pair, everything strictly between the two
    markers is replaced by ``calculate_image_placeholder(vision_config)``
    copies of the token that follows the begin marker. All copies share the
    position id ``boi_position + 1``. The tokens from the end marker onwards
    keep their original position ids.

    ``llm_inputs["prompt_token_ids"]`` and ``llm_inputs["position_ids"]`` are
    overwritten in place and ``llm_inputs`` is returned.

    Raises:
        NotImplementedError: if ``vision_config`` is present but not a dict.
        ValueError: if the begin/end image markers are unbalanced or an end
            marker does not come after its begin marker.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return llm_inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    model_config = ctx.model_config
    # Use the configured tokenizer path (a separate field from the model
    # path), consistent with mm_input_mapper_for_glmv.
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": llm_inputs['multi_modal_data']["image"],
                "content": llm_inputs['prompt']
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process content (%s)", llm_inputs['prompt'])
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_positions = find_all_positions(input_ids, hf_config.boi_token_id)
    eoi_positions = find_all_positions(input_ids, hf_config.eoi_token_id)
    if len(boi_positions) != len(eoi_positions):
        raise ValueError(
            f"Unbalanced image markers: {len(boi_positions)} begin-of-image "
            f"vs {len(eoi_positions)} end-of-image tokens")

    new_input_ids: List[int] = []
    new_position_ids: List[int] = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        if boi_position >= eoi_position:
            raise ValueError(
                f"End-of-image marker at {eoi_position} does not follow its "
                f"begin-of-image marker at {boi_position}")
        # Copy the text since the previous image, including the boi marker,
        # with its original position ids.
        segment_end = boi_position + 1
        new_input_ids.extend(input_ids[final_processed_position:segment_end])
        new_position_ids.extend(range(final_processed_position, segment_end))
        # Replace the image span with placeholder copies sharing one position.
        new_input_ids.extend([input_ids[segment_end]] *
                             image_placeholder_length)
        new_position_ids.extend([segment_end] * image_placeholder_length)
        # Resume at the eoi marker so it is copied with the next segment.
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])
    new_position_ids.extend(range(final_processed_position, len(input_ids)))

    assert len(new_input_ids) == len(new_position_ids)

    llm_inputs["prompt_token_ids"] = new_input_ids
    llm_inputs["position_ids"] = new_position_ids
    return llm_inputs
225

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
226
227
228

class GLMAttention(nn.Module):
    """Self-attention block of a ChatGLM / GLM-4 layer.

    Q, K and V come from one fused ``query_key_value`` projection. K/V use
    ``config.multi_query_group_num`` heads when
    ``config.multi_query_attention`` is set (grouped-query attention), else
    one per query head. Rotary embeddings cover ``head_dim // 2`` of each
    head's dimensions, and ``dense`` projects the attention output back to
    ``hidden_size``.
    """

    # Extra output columns present in the fused QKV projection when the
    # ``FA_PAD`` environment variable is "1"; forward() slices them off.
    # NOTE(review): presumably added by pad_weight(..., 32) in
    # ChatGLMForCausalLM.load_weights — keep the two values in sync.
    _FA_PAD_SIZE = 32

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Name of the quantization method, or None when unquantized.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Read once here rather than on every forward() call.
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply rotary self-attention to ``hidden_states``.

        Returns the output of the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        if self.use_fa_pad and self.quant_method is None:
            # Drop the padding columns appended to the fused QKV output.
            # NOTE(review): this assumes the QKV weights really were padded;
            # in this file load_weights only pads them when LLAMA_NN=1 too —
            # confirm FA_PAD is never enabled on its own.
            qkv = qkv[..., :-self._FA_PAD_SIZE]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated MLP of a ChatGLM / GLM-4 layer.

    ``dense_h_to_4h`` fuses two projections of width ``ffn_hidden_size``.
    ``SiluAndMul`` combines the two halves (per vLLM's ``SiluAndMul``:
    ``silu(x[:d]) * x[d:]``). ``dense_4h_to_h`` projects the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Fused gate/up projection: hidden_size -> 2 x ffn_hidden_size.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Project back to hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        """Map ``hidden_states`` (assumed ``[num_tokens, hidden_size]``, as
        noted in GLMBlock) through the gated MLP; same shape out."""
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Runs self-attention on the input-normalized activations, adds a
    residual, then runs the MLP on the post-attention-normalized activations
    and adds a second residual. Input and output are ``[num_tokens, h]``.
    When ``apply_residual_connection_post_layernorm`` is set, each residual
    branch uses the normalized activations instead of the un-normalized ones.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        # Stored from the config; not read by forward().
        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, cache_config, quant_config)
        # Stored from the config; forward() applies no dropout.
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``config.num_layers`` GLMBlocks with an optional final norm."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList([
            GLMBlock(config, cache_config, quant_config)
            for _ in range(self.num_layers)
        ])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run every layer in order; layer ``i`` uses ``kv_caches[i]``.

        The final layer norm is applied when ``post_layer_norm`` is set.
        """
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[layer_idx],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM / GLM-4 backbone.

    Holds the token embedding, the GLMTransformer encoder and the LM head
    (``output_layer``). When the config has a ``vision_config``, it also
    holds an ``EVA2CLIPModel`` vision tower for GLM-4V.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config)

        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config, quant_config)
        else:
            self.vision = None

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Extract ``pixel_values`` from ``kwargs`` and normalize its layout.

        A tensor with more than two dims is split along dim 0 and the pieces
        are concatenated, which merges its two leading dims. A list of
        tensors is concatenated along dim 0. The result's ``pixel_values``
        is ``None`` when no image was passed.

        Raises:
            ValueError: if pixel values are passed to a model without a
                vision tower.
            TypeError: if ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return GLMImagePixelInputs(pixel_values=None)
        if self.vision is None:
            # Without this check forward() would call ``None(pixel_values)``.
            raise ValueError("pixel_values were provided but this model has "
                             "no vision tower (config has no vision_config)")

        if isinstance(pixel_values, torch.Tensor):
            if pixel_values.ndim > 2:
                pixel_values = torch.concat(list(pixel_values))
        elif isinstance(pixel_values, list):
            # Still wrap the result below: forward() indexes the returned
            # dict, so a bare tensor must not be returned here.
            pixel_values = torch.concat(pixel_values)
        else:
            raise TypeError("pixel_values must be a torch.Tensor "
                            "or a list of torch.Tensor")
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Embed ``input_ids``, splice in image embeddings, run the encoder.

        Image embeddings from the vision tower replace the embeddings of the
        begin-of-image ... end-of-image spans. Returns the encoder's final
        hidden states; ``output_layer`` is not applied here and
        ``intermediate_tensors`` is not used.
        """
        inputs_embeds = self.embedding(input_ids)

        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input["pixel_values"] is not None:
            pixel_values = image_input["pixel_values"].to(
                dtype=inputs_embeds.dtype)
            image_embeds = self.vision(pixel_values)

            inputs_embeds = merge_glm_vision_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                vision_embeddings=image_embeds,
                boi_token_id=self.config.boi_token_id,
                eoi_token_id=self.config.eoi_token_id)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states


573
574
575
576
577
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
578
579
580
581
582
583
584
585
586
587
588
589
590
    # Fused modules. Each maps only to itself: load_weights loads
    # query_key_value and dense_h_to_4h directly by name rather than stacking
    # separate checkpoint shards.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes: linear layers that may receive adapters.
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    # Empty: no embedding / LM-head modules are registered for LoRA here.
    embedding_modules = {}
    embedding_padding_modules = []
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
591

592
593
594
    def __init__(
        self,
        config: ChatGLMConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the ChatGLM backbone, LM head, logits processor and sampler.

        Also records the ``LLAMA_NN``, ``GEMM_PAD`` and ``FA_PAD``
        environment switches that ``load_weights`` consults.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()

        # Name of the quantization method, or None when unquantized;
        # load_weights applies its weight-layout tweaks only when None.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
625

626
627
628
629
630
631
632
    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Run the ChatGLM backbone and return its final hidden states.

        Extra keyword arguments (e.g. ``pixel_values``) are forwarded to
        ``ChatGLMModel``; ``intermediate_tensors`` is not used.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, **kwargs)
        return hidden_states

637
638
639
640
641
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``.

        Delegates to ``self.logits_processor``. The result may be ``None``
        per the return annotation (presumably on ranks that do not hold the
        gathered logits — confirm against LogitsProcessor).
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

646
647
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

654
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
655
656
657
658
659
660
661
662
        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
            "transformer.vision.linear_proj.merged_proj.weight": {
                "transformer.vision.linear_proj.gate_proj.weight": None,
                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
            }
        }
        
663
        params_dict = dict(self.named_parameters(remove_duplicate=False))
664
        for name, loaded_weight in weights:
665
666
667
668
669
670
671
672
            is_weight_to_be_merge = False
            for _, merged_weight_dict in merged_weights_dict.items():
                if name in merged_weight_dict:
                    assert merged_weight_dict[name] is None
                    merged_weight_dict[name] = loaded_weight
                    is_weight_to_be_merge = True
            if is_weight_to_be_merge:
                continue
673
674
            if "rotary_pos_emb.inv_freq" in name:
                continue
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
675
676
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
CHU Tianxiang's avatar
CHU Tianxiang committed
677
678
679
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
680
681
682
683
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
684
685
686
687
688
689
690
691
692
            
        for combined_name, merged_weight_dict in merged_weights_dict.items():
            if combined_name in params_dict:
                param = params_dict[combined_name]
                combined_weight = torch.cat(list(merged_weight_dict.values()),
                                            dim=0)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, combined_weight)
zhuwenwen's avatar
zhuwenwen committed
693
        
694
        if self.use_llama_nn and self.quant_method is None:
zhuwenwen's avatar
zhuwenwen committed
695
696
697
698
            lay_key_words = [
                "self_attention.query_key_value.weight",
                "self_attention.dense.weight",
                "mlp.dense_h_to_4h.weight",
699
700
                "mlp.dense_4h_to_h.weight",
                "lm_head.weight"
zhuwenwen's avatar
zhuwenwen committed
701
702
703
            ]
            combined_words = "|".join(lay_key_words)
            
zhuwenwen's avatar
zhuwenwen committed
704
705
706
707
708
709
            lay_qkv_words = ["self_attention.query_key_value.weight"]   
            qkv_words = "|".join(lay_qkv_words)  
            
            lay_qkv_bias_words = ["self_attention.query_key_value.bias"]   
            qkv_bias_words = "|".join(lay_qkv_bias_words)
            
zhuwenwen's avatar
zhuwenwen committed
710
            for layername, weight in params_dict.items():
zhuwenwen's avatar
zhuwenwen committed
711
712
713
                if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
                    weight.data = pad_weight(weight.data, 32)
                    
zhuwenwen's avatar
zhuwenwen committed
714
                matches = re.findall(combined_words, layername)
715
716
717
718
                if matches:  
                    if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                        weight.data = pad_weight(weight.data, 32)  
                        
zhuwenwen's avatar
zhuwenwen committed
719
720
721
                    if self.use_fa_pad and (re.findall(qkv_words, layername)):
                        if not gemm_bank_conf(weight.data.shape[0]):
                            weight.data = pad_weight(weight.data, 32)
722
                                        
zhuwenwen's avatar
zhuwenwen committed
723
724
725
726
727
728
                    _weight = torch.zeros_like(weight.data)
                    ori_shape =_weight.shape
                    
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
                    weight.data.copy_(_weight)
                    
729
                    weight.data=weight.data.reshape(ori_shape[1], -1)