baichuan.py 20.7 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
21
import math
22
from typing import Iterable, List, Optional, Tuple
codethazine's avatar
codethazine committed
23
24

import torch
25
from torch import nn
26
from transformers import PretrainedConfig
zhuwenwen's avatar
zhuwenwen committed
27
import os
zhuwenwen's avatar
zhuwenwen committed
28
import re
codethazine's avatar
codethazine committed
29

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig, LoRAConfig
32
33
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
34
35
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
41
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
42
from vllm.model_executor.layers.rotary_embedding import get_rope
43
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    ParallelLMHead, VocabParallelEmbedding)
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
47
from vllm.model_executor.sampling_metadata import SamplingMetadata
48
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
49

50
from .interfaces import SupportsLoRA
codethazine's avatar
codethazine committed
51

zhuwenwen's avatar
zhuwenwen committed
52
from vllm import _custom_ops as ops
53
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
codethazine's avatar
codethazine committed
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block of a Baichuan decoder layer.

    The input goes through a merged gate/up projection, the ``SiluAndMul``
    activation, and then a down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and the activation.

        Args:
            hidden_size: Input and output feature size of the block.
            intermediate_size: Output size of each of the two merged
                gate/up projections and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before building any layers.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Query, key and value come from one packed projection (``W_pack``).
    Position information is supplied either through ALiBi slopes handed to
    the attention backend (``position_embedding == "ALIBI"``) or through
    rotary embeddings applied to q/k (any other value).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and the attention backend.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of attention heads across all
                tensor-parallel ranks; must be divisible by the
                tensor-parallel world size.
            position_embedding: ``"ALIBI"`` selects ALiBi; any other value
                selects rotary embeddings.
            rope_theta: Base passed to ``get_rope`` (rotary path only).
            max_position_embeddings: Maximum position passed to
                ``get_rope`` (rotary path only).
            cache_config: Optional cache config forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to the
                linear layers and to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world_size == 0
        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tp_world_size
        self.head_dim = hidden_size // self.total_num_heads
        # NOTE: the attribute keeps its original (misspelled) name so that
        # any external reader of ``postion_embedding`` keeps working.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Always defined, so readers need no hasattr() check.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.scaling = self.head_dim**-0.5
        if self.postion_embedding == "ALIBI":
            # Create the alibi slopes and slice out this rank's heads.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config is forwarded here exactly as in the rotary
            # branch; previously it was dropped on the ALiBi path.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

        # Read the FA_PAD switch once instead of on every forward call.
        # The trailing columns are only stripped for unquantized weights.
        self._strip_fa_pad = (os.environ.get('FA_PAD') == '1'
                              and self.quant_method is None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: Token positions; only used on the rotary path.
            hidden_states: Input to the packed QKV projection.
            kv_cache: KV cache tensor handed to the attention backend.
            attn_metadata: Attention metadata handed to the backend.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.W_pack(hidden_states)
        if self._strip_fa_pad:
            # Drop the last 32 columns of the packed projection output.
            # NOTE(review): presumably these are the columns added when
            # W_pack is padded at weight-load time under FA_PAD=1 -- confirm
            # against pad_weight().
            qkv = qkv[..., :-32]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

    The residual stream is threaded through the two ``RMSNorm`` calls, which
    are invoked with ``(hidden_states, residual)`` and return both values.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the attention, MLP and normalization sub-modules.

        Args:
            config: Model config. Reads ``hidden_size``,
                ``num_attention_heads``, ``intermediate_size``,
                ``hidden_act`` and ``rms_norm_eps``; ``rope_theta`` and
                ``max_position_embeddings`` are optional and default to
                10000 and 8192.
            position_embedding: Forwarded to ``BaiChuanAttention``.
            cache_config: Optional cache config forwarded to the attention.
            quant_config: Optional quantization config forwarded to the
                attention and the MLP.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Token positions forwarded to the attention.
            hidden_states: Layer input.
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Attention metadata forwarded to the attention.
            residual: Residual from the previous layer, or ``None`` for the
                first layer.

        Returns:
            The MLP output and the residual to hand to the next layer.
        """
        # Self Attention
        if residual is None:
            # First layer: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class BaiChuanModel(nn.Module):
    """Token embedding, a stack of ``BaiChuanDecoderLayer`` and a final norm."""

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the embedding, decoder layers and final RMSNorm.

        Args:
            config: Model config. Reads ``pad_token_id``, ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and ``rms_norm_eps``
                here; it is also passed to every decoder layer.
            position_embedding: Forwarded to every decoder layer.
            cache_config: Optional cache config forwarded to every layer.
            quant_config: Optional quantization config forwarded to every
                layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding, cache_config,
                                 quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids fed to the embedding.
            positions: Token positions forwarded to every layer.
            kv_caches: One KV cache tensor per decoder layer, indexed by
                layer position.
            attn_metadata: Attention metadata forwarded to every layer.

        Returns:
            The hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and creates it itself.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # The final norm takes the last residual too; only its first
        # return value is kept.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


311
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
    """Shared causal-LM wrapper for the Baichuan variants.

    Holds the decoder (``BaiChuanModel``), the LM head, the logits processor
    and the sampler, and implements checkpoint loading including the
    optional, environment-controlled weight re-layout passes.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        position_embedding: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the model, LM head, logits processor and sampler.

        Args:
            config: Model config. Reads ``vocab_size``, ``hidden_size`` and
                ``tie_word_embeddings`` here; also passed to the decoder.
            position_embedding: Forwarded to ``BaiChuanModel``.
            cache_config: Optional cache config forwarded to the decoder.
            quant_config: Optional quantization config forwarded to the
                decoder and the LM head.
            lora_config: Optional LoRA config; only stored on the instance.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        # Name of the quantization method, or None when unquantized.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.model = BaiChuanModel(config, position_embedding, cache_config,
                                   quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        # Environment switches for the post-load passes in load_weights().
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        # load_weights() reads this flag on the AWQ path, but it was never
        # initialised, which raised AttributeError there.
        # NOTE(review): the variable name AWQ_PAD is assumed by analogy with
        # the switches above -- confirm against the deployment docs.
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the decoder's hidden states.

        ``intermediate_tensors`` is accepted but not used by this method.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head over ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler over ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, then apply the optional re-layout passes.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        After all tensors are loaded:
          * with ``LLAMA_NN=1`` and no quantization, selected weights are
            padded (optional) and rewritten by ``ops.trans_w16_gemm``;
          * with the ``"awq"`` quantization method, the AWQ tensors are
            converted by ``ops.convert_s4`` / ``ops.sz_permute``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name == "lm_head.weight":
                # Unlike Baichuan, Baichuan2 normalizes the head weights.
                # Refer to:
                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                # Distinguish between Baichuan and Baichuan2 by checking the
                # vocab size. This is suggested by
                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
                is_baichuan2 = self.config.vocab_size == 125696
                if is_baichuan2:
                    loaded_weight = torch.nn.functional.normalize(
                        loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not one of the stacked (gate/up) shards.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_dense_weights(params_dict)

        if self.quant_method == "awq":
            self._relayout_awq_weights(params_dict)

    def _relayout_dense_weights(self, params_dict) -> None:
        """Pad (optionally) and re-lay-out the unquantized GEMM weights."""
        layer_keys = (
            "self_attn.W_pack.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "lm_head.weight",
        )
        qkv_key = "self_attn.W_pack.weight"

        for layername, weight in params_dict.items():
            # Plain substring match; the keys contain no wildcards.
            if not any(key in layername for key in layer_keys):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # With FA_PAD the packed QKV weight is padded unless
            # gemm_bank_conf() is true for its (possibly already padded)
            # first dimension.
            if self.use_fa_pad and qkv_key in layername:
                if not gemm_bank_conf(weight.data.shape[0]):
                    weight.data = pad_weight(weight.data, 32)

            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)

            # Reinterpret the buffer with the original second dimension
            # as the leading one.
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _relayout_awq_weights(self, params_dict) -> None:
        """Convert AWQ qweight/qzeros/scales into the packed kernel layout."""
        layer_keys = (
            "self_attn.W_pack.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        )
        group_size = self.quant_config.group_size
        pad_group = 2

        for layername, qweight in params_dict.items():
            if not any(key in layername for key in layer_keys):
                continue

            # Sibling tensors share the layer prefix of the qweight.
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # reshape
            # [k/group_size, n] ------> [n, k/group_size]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            # [k, n/8] ----> [n, k/8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Create the zero padding on the device of the tensor being
                # padded instead of assuming the current CUDA device.
                zeros_and_scales_pad = torch.zeros(
                    dim_n,
                    pad_group,
                    dtype=torch.int32,
                    device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zeros_and_scales_pad),
                    dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

505

506
507
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Pick the position-embedding scheme from ``config.hidden_size``.

        Args:
            config: Model config; ``hidden_size`` selects the scheme.
            cache_config: Optional cache config forwarded to the base class.
            quant_config: Optional quantization config forwarded to the
                base class.
            lora_config: Optional LoRA config forwarded to the base class.
        """
        # hidden_size 4096 -> baichuan2 7b (rotary);
        # anything else -> baichuan 13b / baichuan2 13b (ALiBi).
        position_embedding = ("ROPE"
                              if config.hidden_size == 4096 else "ALIBI")
        super().__init__(config, position_embedding, cache_config,
                         quant_config, lora_config)
522
523


524
525
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Construct the base model with the ``"ROPE"`` position embedding.

        Args:
            config: Model config forwarded to the base class.
            cache_config: Optional cache config forwarded to the base class.
            quant_config: Optional quantization config forwarded to the
                base class.
            lora_config: Optional LoRA config forwarded to the base class.
        """
        super().__init__(config, "ROPE", cache_config, quant_config,
                         lora_config)