# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn

from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON

from .llama import LlamaDecoderLayer
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                    is_pp_missing_parameter)


class TeleChat2Model(LlamaModel):
    """Llama-style decoder stack adapted to TeleChat2 checkpoints.

    The model is built as a ``LlamaModel`` with biases enabled, after which
    the biases of the fused ``qkv_proj`` and ``gate_up_proj`` projections are
    removed. ``load_weights`` maps the checkpoint's separate ``query`` /
    ``key_value`` / ``gate_proj`` / ``up_proj`` tensors onto those fused
    parameters.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model and read the optional layout/padding switches.

        Args:
            vllm_config: Engine configuration. Its ``hf_config`` is mutated
                in place (``bias`` and ``mlp_bias`` are forced to ``True``).
            prefix: Parameter-name prefix forwarded to ``LlamaModel``.

        Raises:
            ValueError: If ``W8A8_SUPPORT_METHODS`` is set to a non-integer.
        """
        # 1. Initialize the LlamaModel with bias
        hf_config = vllm_config.model_config.hf_config
        hf_config.bias = True
        hf_config.mlp_bias = True
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # 2. Remove the bias from the qkv_proj and gate_up_proj based on config
        # Telechat2's gate_up_proj and qkv_proj don't have bias
        # see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566
        for layer in self.layers:
            if not isinstance(layer, PPMissingLayer):
                layer.self_attn.qkv_proj.bias = None
                layer.self_attn.qkv_proj.skip_bias_add = True
                layer.mlp.gate_up_proj.bias = None
                layer.mlp.gate_up_proj.skip_bias_add = True

        # Name of the active quantization scheme, or None when unquantized.
        self.quant_method = None
        if vllm_config.quant_config is not None:
            self.quant_method = vllm_config.quant_config.get_name()
            self.quant_config = vllm_config.quant_config

        # NOTE(review): purpose of this helper object is not visible here;
        # it is not read anywhere else in this class.
        self.tritonsingleton = W8a8GetCacheJSON()
        # Switches read once from the environment. Only ``use_llama_nn`` is
        # consulted in this class; the others are stored but not read here.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into the fused vLLM parameters.

        Args:
            weights: ``(name, tensor)`` pairs. Every tensor must expose
                ``current_count`` and ``total_count`` attributes; the values
                on the last yielded tensor decide whether the post-load
                layout conversion runs.

        Returns:
            The names of the parameters that were loaded. Tensors that belong
            to layers missing on this pipeline-parallel rank are skipped.
        """
        stacked_params_mapping = [
            ('gate_up_proj', 'gate_proj', 0),
            ('gate_up_proj', 'up_proj', 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        total_num_heads = self.config.n_head
        head_dim = self.config.hidden_size // total_num_heads
        # Stays False when ``weights`` is empty, so the post-load step below
        # never touches an unbound name.
        is_final_batch = False
        for name, loaded_weight in weights:
            is_final_batch = (loaded_weight.current_count ==
                              loaded_weight.total_count)
            if "self_attn.key_value" in name:
                # The fused tensor interleaves K and V per head along dim 0:
                # head_dim rows of K followed by head_dim rows of V.
                k_weight = []
                v_weight = []
                for i in range(total_num_heads):
                    start = i * head_dim * 2
                    middle = start + head_dim
                    k_weight.append(loaded_weight[start:middle])
                    v_weight.append(loaded_weight[middle:middle + head_dim])
                k_weight = torch.cat(k_weight, dim=0)
                v_weight = torch.cat(v_weight, dim=0)
                name = name.replace("key_value", "qkv_proj")
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, k_weight, "k")
                weight_loader(param, v_weight, "v")
            elif "query" in name:
                name = name.replace("query", "qkv_proj")
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, "q")
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Not a stacked projection (or it was skipped above):
                    # load it directly unless it lives on another PP rank.
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if (self.use_llama_nn and self.quant_method is None
                and is_final_batch):
            self._convert_weights_to_nn_layout(params_dict)
        return loaded_params

    def _convert_weights_to_nn_layout(
            self, params_dict: dict[str, nn.Parameter]) -> None:
        """Rewrite the projection weights in place via ``ops.trans_w16_gemm``.

        Every parameter whose name contains one of the projection weight
        names is passed through ``ops.trans_w16_gemm`` and then reshaped to
        ``(original_shape[1], -1)``. The ``LM_NN`` environment variable is
        set to ``'1'`` if an ``lm_head.weight`` with ``shape[1] >= 4096`` is
        found among the parameters and to ``'0'`` otherwise.

        NOTE(review): the exact layout produced by ``trans_w16_gemm`` is not
        visible from this file -- presumably a transposed GEMM layout.
        """
        layer_keys = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        # Default to '0' and only ever raise it to '1', so the flag no longer
        # depends on lm_head being the last parameter iterated.
        os.environ['LM_NN'] = '0'
        for layername, weight in params_dict.items():
            if "lm_head.weight" in layername and weight.shape[1] >= 4096:
                layer_keys.append("lm_head.weight")
                os.environ['LM_NN'] = '1'

            # Literal substring match; the names contain '.', which must not
            # be treated as a regex wildcard.
            if not any(key in layername for key in layer_keys):
                continue

            converted = torch.zeros_like(weight.data)
            ori_shape = converted.shape
            ops.trans_w16_gemm(converted, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(converted)
            weight.data = weight.data.reshape(ori_shape[1], -1)


class TeleChat2ForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper that plugs ``TeleChat2Model`` into the Llama head.

    Checkpoint parameter names are translated to vLLM's naming through
    ``hf_to_vllm_mapper`` before loading.
    """

    # Renames applied to checkpoint names: the ``transformer.`` prefix
    # becomes ``model.``, then the listed substrings are substituted.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"transformer.": "model."},
        orig_to_new_substr={
            ".h.": ".layers.",
            ".self_attention.": ".self_attn.",
            ".word_embeddings.": ".embed_tokens.",
            ".dense.": ".o_proj.",
            ".ln_f.": ".norm.",
        },
    )

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Create the backbone; ``layer_type`` is accepted but not used."""
        backbone = TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
        return backbone

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``weights`` after renaming them with ``hf_to_vllm_mapper``.

        ``lm_head.`` tensors are skipped when the config ties the word
        embeddings.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded parameter names).
        """
        if self.config.tie_word_embeddings:
            prefixes_to_skip = ["lm_head."]
        else:
            prefixes_to_skip = None

        auto_loader = AutoWeightsLoader(self, skip_prefixes=prefixes_to_skip)
        return auto_loader.load_weights(weights,
                                        mapper=self.hf_to_vllm_mapper)