# -*- coding: utf-8 -*-
from typing import Optional

import torch
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
class InternLM2ForCausalLM(LlamaForCausalLM):
    """InternLM2 causal language model.

    InternLM2 reuses the Llama architecture, so all model construction and
    the forward pass are inherited from ``LlamaForCausalLM`` unchanged.
    Only ``load_weights`` is overridden, because InternLM2 checkpoints use
    different parameter names and pack the attention Q/K/V projections into
    a single interleaved ``wqkv`` tensor that must be unpacked before it can
    be fed to the stacked ``qkv_proj`` layer.
    """

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # Model construction is fully delegated to the Llama implementation.
        super().__init__(config=config,
                         linear_method=linear_method,
                         lora_config=lora_config)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Load InternLM2 checkpoint weights into the Llama-shaped model.

        Args:
            model_name_or_path: HF model id or local checkpoint path.
            cache_dir: Optional download cache directory.
            load_format: Checkpoint format passed through to the weight
                iterator (default ``"auto"``).
            revision: Optional checkpoint revision.
        """
        # Checkpoint shards that map onto a single stacked parameter:
        # the gate (w1) and up (w3) MLP projections both load into
        # gate_up_proj, at shard index 0 and 1 respectively.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        # Substring renames from InternLM2 checkpoint names to the Llama
        # parameter names used by this model (applied in order via
        # str.replace on every weight name).
        param_weight_map = [
            ("qkv_proj", "wqkv"),
            ("o_proj", "wo"),
            ("down_proj", "w2"),
            ("input_layernorm", "attention_norm"),
            ("post_attention_layernorm", "ffn_norm"),
            ("embed_tokens", "tok_embeddings"),
            (".self_attn.", ".attention."),
            ("mlp", "feed_forward"),
            ("lm_head", "output"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # Rewrite the checkpoint name into this model's naming scheme.
            for (param_name, weight_name) in param_weight_map:
                name = name.replace(weight_name, param_name)

            # Rotary inverse-frequency buffers are recomputed at runtime,
            # so the stored copy is skipped.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked MLP shard: load directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "qkv_proj" in name:
                    # InternLM2 stores wqkv interleaved per KV group:
                    # viewed as (num_kv_heads, kv_groups + 2, head_dim, in),
                    # where each group holds kv_groups query heads followed
                    # by one key head and one value head.
                    config = self.config
                    kv_groups = config.num_attention_heads // config.num_key_value_heads
                    head_dim = config.hidden_size // config.num_attention_heads
                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                       head_dim,
                                                       loaded_weight.shape[-1])
                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
                                             dim=1)
                    # Flatten each piece back to a 2-D (out, in) weight.
                    wq = wq.reshape(-1, wq.shape[-1])
                    wk = wk.reshape(-1, wk.shape[-1])
                    wv = wv.reshape(-1, wv.shape[-1])
                    # Feed the unpacked pieces to the stacked QKV layer's
                    # loader under their respective shard ids.
                    weight_loader = param.weight_loader
                    weight_loader(param, wq, 'q')
                    weight_loader(param, wk, 'k')
                    weight_loader(param, wv, 'v')
                else:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)