# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig


class GLMAttention(nn.Module):
    """ChatGLM self-attention layer (multi-head or multi-query/grouped).

    Hidden states are projected to Q, K and V by one fused tensor-parallel
    projection (``query_key_value``). Rotary position embeddings are applied
    to Q and K. ``Attention`` is run against the layer's KV cache, and the
    result is projected back to ``hidden_size`` by a row-parallel linear
    layer (``dense``).
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the attention projections, rotary embedding and kernel.

        Args:
            config: ChatGLM model config. Reads ``hidden_size``,
                ``num_attention_heads``, ``multi_query_attention``,
                ``multi_query_group_num``, ``add_bias_linear`` and
                ``add_qkv_bias``. Also reads ``rope_ratio`` and
                ``seq_length`` when present.
            linear_method: Optional linear method forwarded to the
                tensor-parallel linear layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Query heads are always split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query attention enabled, K/V use only
        # `multi_query_group_num` heads; otherwise every query head has its
        # own K/V head.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # When KV heads are replicated, each rank still holds one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices in the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        # `rope_ratio` scales the rotary base (long-context variants set it).
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # Rotary embedding covers only half of each head's dimensions, and
        # uses get_rope's non-NeoX rotation layout.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention on a batch of tokens.

        Args:
            hidden_states: Layer input. Presumably token-flattened
                ``[num_tokens, hidden_size]``; TODO confirm.
            position_ids: Token positions used by the rotary embedding.
            kv_cache: This layer's KV cache, passed to ``self.attn``.
            attn_metadata: Batch attention metadata, passed to ``self.attn``.

        Returns:
            Output of the ``dense`` projection (last dim ``hidden_size``).
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is laid out as [Q | K | V] on the
        # last dimension.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward block of a ChatGLM layer.

    The input is projected to two ``ffn_hidden_size``-wide halves in one
    fused column-parallel projection. ``SiluAndMul`` combines the two halves
    (SiLU gate times value). A row-parallel projection then maps the result
    back to ``hidden_size``.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections.

        Args:
            config: ChatGLM model config. Reads ``hidden_size``,
                ``ffn_hidden_size`` and ``add_bias_linear``.
            linear_method: Optional linear method forwarded to the
                tensor-parallel linear layers.
        """
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Fused gate + up projection: two ffn_hidden_size outputs merged
        # into one column-parallel layer.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

        self.activation_func = SiluAndMul()

        # Project back to hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP.

        Args:
            hidden_states: Input activations with last dim ``hidden_size``.

        Returns:
            Output activations with last dim ``hidden_size``.
        """
        # Last dim: the gate and up halves together (per-rank slice under
        # tensor parallelism, presumably).
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        # Gate the halves together; the last dim halves.
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # Last dim: hidden_size.
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single ChatGLM transformer layer.

    The layer applies input layernorm, then self-attention with a residual
    add. It then applies post-attention layernorm, then the MLP with a
    residual add. It takes token-flattened hidden states (``[num_tokens, h]``
    per the forward comment) and returns a tensor of the same shape.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the norms, attention and MLP sub-modules.

        Args:
            config: ChatGLM model config. Reads
                ``apply_residual_connection_post_layernorm``,
                ``fp32_residual_connection``, ``rmsnorm``, ``hidden_size``,
                ``layernorm_epsilon`` and ``hidden_dropout``, and is passed
                on to ``GLMAttention`` / ``GLMMLP``.
            linear_method: Optional linear method forwarded to sub-modules.
        """
        super().__init__()
        # Chooses whether residuals come from the normalized or the
        # un-normalized tensor (see forward).
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        # Stored from the config; not used by forward below.
        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, linear_method)
        # Stored from the config; forward applies no dropout.
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one transformer layer.

        Args:
            hidden_states: ``[num_tokens, h]`` layer input.
            position_ids: Token positions, passed to attention.
            kv_cache: This layer's KV cache.
            attn_metadata: Batch attention metadata.

        Returns:
            Layer output with the same shape as ``hidden_states``.
        """
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection: from the normed tensor if configured,
        # otherwise from the raw layer input.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection, chosen the same way.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layernorm."""

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build ``config.num_layers`` blocks and, optionally, a final norm.

        Args:
            config: ChatGLM model config. Reads ``post_layer_norm``,
                ``num_layers``, ``rmsnorm``, ``hidden_size`` and
                ``layernorm_epsilon``, and is passed on to each ``GLMBlock``.
            linear_method: Optional linear method forwarded to every block.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config, linear_method) for _ in range(self.num_layers)])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run every layer in order, then the final norm if configured.

        Args:
            hidden_states: Input embeddings.
            position_ids: Token positions, passed to every layer.
            kv_caches: One KV cache per layer. Indexed by layer number, so
                it must have at least ``num_layers`` entries (an IndexError
                is raised otherwise).
            attn_metadata: Batch attention metadata.

        Returns:
            Final hidden states.
        """
        # Index kv_caches explicitly rather than zipping, so that a short
        # cache list fails loudly instead of silently skipping layers.
        for i, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, transformer stack and LM head.

    ``forward`` returns hidden states only. ``output_layer`` is owned here
    but applied by the caller (its weight is consumed for logits elsewhere).
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, encoder and output layer.

        Args:
            config: ChatGLM model config. Reads ``padded_vocab_size``,
                ``hidden_size``, ``num_layers``, ``multi_query_group_num``
                and ``kv_channels``, and is passed on to the encoder.
            linear_method: Optional linear method forwarded to the encoder.
        """
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, linear_method)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run the transformer stack.

        Args:
            input_ids: Token ids to embed.
            position_ids: Token positions for rotary embeddings.
            kv_caches: One KV cache per layer.
            attn_metadata: Batch attention metadata.

        Returns:
            Final hidden states (not logits).
        """
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states


class ChatGLMForCausalLM(nn.Module):
    """ChatGLM causal language model wrapper.

    Wraps ``ChatGLMModel`` and adds logits computation (from the output
    layer weight), sampling, and loading of checkpoint weights.
    """

    # Fused modules keyed by their own name; each maps to a single
    # checkpoint module of the same name.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the backbone, logits processor and sampler.

        Args:
            config: ChatGLM model config.
            linear_method: Optional linear method forwarded to the backbone.
            lora_config: Accepted for interface compatibility; not used in
                this constructor.
        """
        super().__init__()
        self.config: ChatGLMConfig = config
        self.linear_method = linear_method
        self.transformer = ChatGLMModel(config, linear_method)
        # Alias of the output layer's weight, used by compute_logits. Being
        # a Parameter, it is also registered on this module under the name
        # `lm_head_weight`.
        self.lm_head_weight = self.transformer.output_layer.weight
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return hidden states (not logits)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this module's parameters.

        Args:
            model_name_or_path: Checkpoint location passed to
                ``hf_model_weights_iterator``.
            cache_dir: Optional download cache directory.
            load_format: Checkpoint format selector (default ``"auto"``).
            revision: Optional model revision.

        Raises:
            KeyError: If a checkpoint tensor (other than the skipped ones
                below) has no matching parameter.
        """
        # Keep aliased names: `lm_head_weight` and
        # `transformer.output_layer.weight` are the same tensor, and the
        # checkpoint name must still resolve after de-duplication would
        # otherwise drop one of them.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # Rotary inverse frequencies are not parameters here; the rotary
            # embedding is built locally via get_rope.
            if "rotary_pos_emb.inv_freq" in name:
                continue
            # `transformer.embedding` is itself the VocabParallelEmbedding,
            # so drop the `.word_embeddings` component from the name.
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)