# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""

from typing import Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GLMAttention(nn.Module):
    """Self-attention sub-layer of a ChatGLM transformer block.

    Hidden states are projected to packed Q/K/V by one tensor-parallel
    ``QKVParallelLinear``. Rotary position embeddings (over half of each
    head's dimensions) are applied to Q and K. vLLM's ``Attention`` layer
    then runs, and the result is projected back to ``config.hidden_size``
    by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention projections, rotary embedding and kernel.

        Args:
            config: ChatGLM model config.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to every linear layer and ``Attention``.
            prefix: Module-name prefix used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel size ({tp_size})")
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query (grouped) attention only `multi_query_group_num`
        # KV heads exist; otherwise every query head has its own KV head.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"number of KV heads ({self.total_num_kv_heads}) must be "
                f"divisible by the tensor parallel size ({tp_size})")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor parallel size ({tp_size}) must be divisible by the "
                f"number of KV heads ({self.total_num_kv_heads})")
        # Each rank holds at least one KV head (replicated when KV < TP).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True. Configs that do not
        # define `original_rope` keep the original ChatGLM (non-NeoX)
        # rotary layout instead of failing with AttributeError, matching
        # the defaulted lookups of `rope_ratio` / `seq_length` above.
        is_neox_style = not getattr(config, "original_rope", True)
        self.rotary_emb = get_rope(
            self.head_dim,
            # ChatGLM rotates only the first half of each head's dims.
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention.

        Args:
            hidden_states: Input activations whose last dim is
                ``hidden_size``.
            position_ids: Positions passed to the rotary embedding.

        Returns:
            The attention output projected back to ``hidden_size``.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # Packed layout along the last dim: [q | k | v] for this TP rank.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(q, k, v)
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward sub-layer of a ChatGLM transformer block.

    ``dense_h_to_4h`` computes two ``ffn_hidden_size``-wide projections in
    one merged column-parallel matmul. ``SiluAndMul`` turns them into a
    single activation, which ``dense_4h_to_h`` (input width
    ``ffn_hidden_size``) projects back to ``hidden_size``. The "4h" names
    are kept for checkpoint compatibility; the inner width is
    ``config.ffn_hidden_size``, not literally ``4 * hidden_size``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the up/gate projection, activation and down projection.

        Args:
            config: ChatGLM model config.
            quant_config: Forwarded to both linear layers.
            prefix: Module-name prefix used to name the sub-layers.
        """
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Two stacked ffn_hidden_size outputs, consumed by SiluAndMul.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )

        self.activation_func = SiluAndMul()

        # Project back to hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the gated MLP.

        Args:
            hidden_states: Activations whose last dim is ``hidden_size``.

        Returns:
            Output activations whose last dim is ``hidden_size``.
        """
        # Last dim: both stacked projections (this rank's shard).
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # Last dim: hidden_size.
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single ChatGLM transformer layer.

    Takes token-flattened hidden states of shape ``[num_tokens, h]`` and
    returns an output of the same shape. The layer is pre-norm: layernorm,
    then self-attention plus residual, then layernorm, then MLP plus
    residual.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the norms, attention and MLP of one layer.

        Args:
            config: ChatGLM model config.
            cache_config: Forwarded to the attention sub-layer.
            quant_config: Forwarded to the attention and MLP sub-layers.
            prefix: Module-name prefix used to name the sub-layers.
        """
        super().__init__()
        # When set, residuals are taken from the layernorm output rather
        # than from the layernorm input.
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        # Stored from the config; not read by forward() below.
        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attention")
        # Stored from the config; no dropout is applied in forward().
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output.
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run one transformer layer.

        Args:
            hidden_states: ``[num_tokens, h]`` input activations.
            position_ids: Positions forwarded to the attention's rotary
                embedding.

        Returns:
            ``[num_tokens, h]`` output activations.
        """
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ChatGLM transformer layers (the model's "encoder").

    Only the layers in ``[start_layer, end_layer)`` assigned to this
    pipeline-parallel stage by ``make_layers`` are executed. The last
    stage optionally applies a final layernorm.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build this stage's layers and the optional final layernorm.

        Args:
            config: ChatGLM model config.
            cache_config: Forwarded to every ``GLMBlock``.
            quant_config: Forwarded to every ``GLMBlock``.
            prefix: Module-name prefix; layers live under
                ``{prefix}.layers``.
        """
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers (only this pipeline stage's slice is real).
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            lambda prefix: GLMBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this stage's layers.

        Args:
            hidden_states: Input activations for the first local layer.
            position_ids: Positions forwarded to every layer.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` on a
            non-last pipeline rank. On the last rank, returns the final
            hidden states, after ``final_layernorm`` when
            ``post_layer_norm`` is set.
        """
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states=hidden_states,
                                  position_ids=position_ids)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, transformer stack and LM head.

    The LM head (``output_layer``) is constructed here but not applied in
    ``forward``. Callers such as ``ChatGLMBaseModel`` use it to compute
    logits.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, encoder and output layer.

        Args:
            vllm_config: Source of the HF config, cache config and
                quantization config.
            prefix: Module-name prefix for all sub-layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.embedding")

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone for this pipeline stage.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. On later ranks it
        is ``intermediate_tensors["hidden_states"]``, which must be
        provided.

        Args:
            input_ids: Token ids (only used on the first rank when
                ``inputs_embeds`` is None).
            positions: Positions forwarded to the encoder.
            intermediate_tensors: Activations from the previous stage.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever ``self.encoder`` returns: final hidden states on the
            last rank, ``IntermediateTensors`` otherwise.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=hidden_states,
            position_ids=positions,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names matching an entry of ``stacked_params_mapping`` are renamed
        to the merged parameter and loaded into the given shard. All other
        names are loaded directly, using the parameter's ``weight_loader``
        when it has one, else ``default_weight_loader``. Rotary
        ``inv_freq`` buffers are skipped. So are biases with no matching
        parameter, and parameters not owned by this pipeline stage.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            # NOTE(review): no `linear_proj` module is built in this file;
            # presumably these entries serve a vision variant -- confirm.
            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this tensor: load it as-is.
                if "rotary_pos_emb.inv_freq" in name:
                    continue
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class ChatGLMBaseModel(nn.Module):
    """Shared causal-LM wrapper around a ChatGLM backbone.

    Owns the backbone (``self.transformer``) and exposes its
    ``output_layer`` as ``lm_head``. Adds logits computation, sampling and
    checkpoint loading.
    """

    # Strips the ".word_embeddings" component from checkpoint names, e.g.
    # "...embedding.word_embeddings.weight" -> "...embedding.weight".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""}, )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[ChatGLMModel] = ChatGLMModel,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Source of the model, quantization, LoRA and
                multimodal configs.
            prefix: Module-name prefix; the backbone lives under
                ``{prefix}.transformer`` (via ``maybe_prefix``).
            transformer_type: Backbone class to instantiate.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = transformer_type(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "transformer"))
        if self.config.tie_word_embeddings:
            # Share one weight tensor between the input embedding and the
            # output projection.
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, renaming them with ``hf_to_vllm_mapper``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` reports as loaded.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
    """Text-only ChatGLM causal language model with LoRA and PP support."""

    # Each fused layer maps to a single checkpoint sub-module of the same
    # name (the QKV and gate/up projections are already fused on disk).
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model, refusing configs that declare a vision tower.

        Raises:
            RuntimeError: If the HF config has a ``vision_config``
                attribute. The message tells the user which
                ``--hf-overrides`` value selects the vision model.
        """
        config = vllm_config.model_config.hf_config
        if hasattr(config, "vision_config"):
            hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
            raise RuntimeError(
                "The configuration of this model indicates that it supports "
                "vision inputs, but you instantiated the text-only version "
                "of this model. Please use the vision model by setting "
                f"`--hf-overrides {hf_overrides!r}`")

        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its hidden states (not logits)."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states