chatglm.py 15.7 KB
Newer Older
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
1
2
3
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
Woosuk Kwon's avatar
Woosuk Kwon committed
4
"""Inference-only ChatGLM model compatible with THUDM weights."""
5
from typing import Iterable, List, Optional, Tuple
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
6
7
8
9

import torch
from torch import nn
from torch.nn import LayerNorm
zhuwenwen's avatar
zhuwenwen committed
10
import os
zhuwenwen's avatar
zhuwenwen committed
11
import re
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
12

13
from vllm.attention import Attention, AttentionMetadata
14
from vllm.config import CacheConfig, LoRAConfig
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
17
from vllm.model_executor.layers.layernorm import RMSNorm
18
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
19
20
                                               QKVParallelLinear,
                                               RowParallelLinear)
21
from vllm.model_executor.layers.logits_processor import LogitsProcessor
22
23
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
24
from vllm.model_executor.layers.rotary_embedding import get_rope
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
25
from vllm.model_executor.layers.sampler import Sampler
26
from vllm.model_executor.layers.vocab_parallel_embedding import (
27
    ParallelLMHead, VocabParallelEmbedding)
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
from vllm.model_executor.sampling_metadata import SamplingMetadata
30
from vllm.sequence import SamplerOutput
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
31
from vllm.transformers_utils.configs import ChatGLMConfig
32

zhuwenwen's avatar
zhuwenwen committed
33
from vllm import _custom_ops as ops
34
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
35
36
37
38


class GLMAttention(nn.Module):
    """Self-attention sub-layer of a ChatGLM transformer block.

    Projects hidden states to a fused Q/K/V tensor with ``QKVParallelLinear``,
    applies rotary embeddings to Q and K, runs ``Attention`` against the KV
    cache, and projects the context back to ``hidden_size`` with
    ``RowParallelLinear``. Query heads are split across tensor-parallel ranks.
    KV heads are either split across ranks or replicated on them, depending
    on how their count compares with the TP world size.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # With FA_PAD=1 the model's load_weights may append padding rows to
        # the fused QKV weight/bias, widening the projection output. Read the
        # flag once here rather than on every forward pass.
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # ChatGLM rotates only half of each head dimension (rotary_dim) and
        # uses the non-NeoX (interleaved) rotary layout.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            hidden_states: layer input (already layer-normed by the caller).
            position_ids: token positions used for the rotary embedding.
            kv_cache: this layer's KV cache tensor.
            attn_metadata: attention metadata for the current batch.
        """
        qkv, _ = self.query_key_value(hidden_states)
        if self.use_fa_pad:
            # Drop any padding columns appended by FA_PAD weight padding.
            # Trimming to the expected width is equivalent to ``[..., :-32]``
            # when the 32-row padding was applied, and is a no-op when it was
            # not (load_weights only pads under LLAMA_NN=1). The old fixed
            # slice cut real V columns in that case.
            qkv = qkv[..., :self.q_size + 2 * self.kv_size]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward sub-layer of a ChatGLM transformer block.

    A single fused column-parallel projection produces two
    ``ffn_hidden_size``-wide outputs. ``SiluAndMul`` combines them, and a
    row-parallel projection maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        with_bias = config.add_bias_linear
        self.add_bias = with_bias

        # Fused up-projection: two ffn_hidden_size outputs in one GEMM.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=with_bias,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Down-projection back to the model width.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=with_bias,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        """Apply up-projection, gated activation and down-projection."""
        fused_up, _ = self.dense_h_to_4h(hidden_states)
        projected, _ = self.dense_4h_to_h(self.activation_func(fused_up))
        return projected


class GLMBlock(nn.Module):
    """One ChatGLM transformer layer (pre-norm attention followed by an MLP).

    The output has the same shape as the input hidden states. The residual
    branch uses either the normalized or the raw input, depending on
    ``config.apply_residual_connection_post_layernorm``.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
        self.fp32_residual_connection = config.fp32_residual_connection

        norm_cls = LayerNorm if not config.rmsnorm else RMSNorm
        width = config.hidden_size
        eps = config.layernorm_epsilon

        # Sub-modules are registered in the same order as the checkpoint
        # layout: input norm, attention, post-attention norm, MLP.
        self.input_layernorm = norm_cls(width, eps=eps)
        self.self_attention = GLMAttention(config, cache_config, quant_config)
        self.hidden_dropout = config.hidden_dropout
        self.post_attention_layernorm = norm_cls(width, eps=eps)
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP with residual connections around each."""
        residual_from_norm = self.apply_residual_connection_post_layernorm

        normed_in = self.input_layernorm(hidden_states)
        attn_out = self.self_attention(
            hidden_states=normed_in,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        first_residual = normed_in if residual_from_norm else hidden_states
        mid = first_residual + attn_out

        normed_mid = self.post_attention_layernorm(mid)
        second_residual = normed_mid if residual_from_norm else mid
        return self.mlp(normed_mid) + second_residual


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layer norm."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers

        self.layers = nn.ModuleList(
            GLMBlock(config, cache_config, quant_config)
            for _ in range(self.num_layers))

        if self.post_layer_norm:
            norm_cls = RMSNorm if config.rmsnorm else LayerNorm
            self.final_layernorm = norm_cls(config.hidden_size,
                                            eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply every layer in order, each with its own entry of ``kv_caches``."""
        for layer_idx, block in enumerate(self.layers):
            hidden_states = block(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[layer_idx],
                attn_metadata=attn_metadata,
            )
        if not self.post_layer_norm:
            return hidden_states
        return self.final_layernorm(hidden_states)


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, transformer encoder, LM-head weights.

    ``output_layer`` is created here so that checkpoint names line up, but
    ``forward`` does not apply it. ``forward`` returns the encoder's hidden
    states.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        vocab = config.padded_vocab_size
        width = config.hidden_size

        self.embedding = VocabParallelEmbedding(vocab, width)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(vocab, width)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through the encoder."""
        embedded = self.embedding(input_ids)
        return self.encoder(
            hidden_states=embedded,
            position_ids=position_ids,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )


class ChatGLMForCausalLM(nn.Module):
    """ChatGLM causal language model for vLLM inference.

    Wraps ``ChatGLMModel`` and adds logits computation, sampling and
    checkpoint loading. When ``LLAMA_NN=1``, ``load_weights`` also pads the
    selected linear weights (``GEMM_PAD`` / ``FA_PAD``) and transposes them
    in place.
    """
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        self.lm_head_weight = self.transformer.output_layer.weight
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()
        # Backend switches, read once at construction time.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits are produced by ``compute_logits``."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM-head weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Skips rotary ``inv_freq`` buffers and extra GPTQ biases, and maps the
        HF ``word_embeddings`` name onto ``embedding``. When LLAMA_NN is
        enabled, it then post-processes the linear weights for that backend.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.use_llama_nn:
            self._prepare_llama_nn_weights(params_dict)

    def _prepare_llama_nn_weights(self, params_dict):
        """Pad and transpose linear weights in place for the LLAMA_NN path."""
        pad_rows = 32
        gemm_keys = (
            "self_attention.query_key_value.weight",
            "self_attention.dense.weight",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_4h_to_h.weight",
        )
        qkv_weight_key = "self_attention.query_key_value.weight"
        qkv_bias_key = "self_attention.query_key_value.bias"

        for layername, weight in params_dict.items():
            # FA_PAD widens the fused QKV output; its bias must match.
            if self.use_fa_pad and qkv_bias_key in layername:
                weight.data = pad_weight(weight.data, pad_rows)

            if not any(key in layername for key in gemm_keys):
                continue

            # Decide on the ORIGINAL shape whether the GEMM pad applies, and
            # remember it. The FA pad must be skipped only when the GEMM pad
            # was actually applied (the QKV weight already got its extra rows).
            # Re-testing gemm_bank_conf on the mutated shape, or without
            # checking GEMM_PAD, can double-pad the weight or leave it
            # unpadded while its bias is padded.
            gemm_padded = bool(self.use_gemm_pad
                               and gemm_bank_conf(weight.data.shape[0]))
            if gemm_padded:
                weight.data = pad_weight(weight.data, pad_rows)

            if (self.use_fa_pad and qkv_weight_key in layername
                    and not gemm_padded):
                weight.data = pad_weight(weight.data, pad_rows)

            # Transpose through a scratch buffer, then view the result with
            # swapped dimensions.
            transposed = torch.zeros_like(weight.data)
            ori_shape = transposed.shape
            ops.trans_w16_gemm(transposed, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(transposed)
            weight.data = weight.data.reshape(ori_shape[1], -1)