# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
import os
import re
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block of a Baichuan decoder layer.

    The gate and up projections are fused into one ``gate_up_proj`` layer;
    its output goes through ``SiluAndMul`` and then ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization settings forwarded to both linear
                layers, or ``None``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported activation fails before any
        # projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply ``gate_up_proj``, the activation and ``down_proj`` to ``x``."""
        # Both linear layers return a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Query, key and value come from a single fused ``W_pack`` projection.
    Positional information is supplied by ALiBi slopes when
    ``position_embedding`` is ``"ALIBI"`` and by rotary embeddings for any
    other value.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections and the attention backend layer.

        Args:
            hidden_size: Model width.
            num_heads: Total number of heads across all tensor-parallel
                ranks; must be divisible by the tensor-parallel world size.
            position_embedding: ``"ALIBI"`` selects ALiBi; anything else
                selects rotary embeddings.
            rope_theta: Rotary base (used only on the rotary path).
            max_position_embeddings: Maximum position passed to the rotary
                embedding (used only on the rotary path).
            cache_config: KV-cache settings forwarded to ``Attention``.
            quant_config: Quantization settings forwarded to the linear
                layers and to ``Attention``.
            prefix: Dotted module path of this layer; ``".attn"`` is appended
                for the inner ``Attention`` layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.scaling = self.head_dim**-0.5

        # Record the quantization method for both position-embedding modes
        # (it used to be set on the rotary path only, leaving ALiBi instances
        # without the attribute).
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        if self.position_embedding == "ALIBI":
            # Create the alibi slopes and keep only this rank's slice.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config is forwarded here as on the rotary path so the
            # KV-cache settings are honoured in both modes.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Token positions; consumed only by the rotary embedding.
            hidden_states: Input activations.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.W_pack(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, then the MLP."""

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the attention, MLP and normalization sub-modules.

        Args:
            config: HuggingFace model config. ``hidden_size``,
                ``num_attention_heads``, ``intermediate_size``, ``hidden_act``
                and ``rms_norm_eps`` are read; ``rope_theta`` and
                ``max_position_embeddings`` fall back to 10000 and 8192 when
                absent.
            position_embedding: Forwarded to ``BaiChuanAttention``.
            cache_config: Forwarded to ``BaiChuanAttention``.
            quant_config: Forwarded to the attention and MLP sub-modules.
            prefix: Dotted module path of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Token positions, forwarded to the attention module.
            hidden_states: Input activations.
            residual: Residual carried over from the previous layer, or
                ``None`` for the first layer handled on this rank.

        Returns:
            ``(hidden_states, residual)`` to hand to the next layer.
        """
        # Self Attention
        if residual is None:
            # No residual carried in yet: the un-normalized input becomes it.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalized
            # activations and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


272
@support_torch_compile
class BaiChuanModel(nn.Module):
    """Baichuan decoder stack: token embedding, decoder layers, final norm."""

    # Parameter-name fragments selecting the weights that get re-laid-out
    # when the LLAMA_NN path is active.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attn.W_pack.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.down_proj.weight",
        "lm_head.weight",
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Source of the HF model config, the cache config and
                the quantization config.
            prefix: Dotted module path of this model; ``".layers"`` is
                appended for the decoder stack.
            position_embedding: Forwarded to every decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(config,
                                                position_embedding,
                                                cache_config,
                                                quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Feature switches read once from the environment. Only
        # ``use_llama_nn`` is consulted by this class; the three *_pad flags
        # are stored but not read here.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline-parallel rank.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is ``None``.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` from the
                previous rank; required on every rank but the first.
            inputs_embeds: Pre-computed embeddings that replace the embedding
                lookup on the first rank.

        Returns:
            The normalized hidden states on the last rank, otherwise an
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        ``gate_proj``/``up_proj`` checkpoint tensors are routed into the
        fused ``gate_up_proj`` parameter as shards 0 and 1. Rotary
        ``inv_freq`` buffers, stray ``.bias`` entries without a matching
        parameter, and parameters reported missing by
        ``is_pp_missing_parameter`` are skipped.

        After loading, the LLAMA_NN weight re-layout is applied when it is
        enabled and the model is not quantized; otherwise the ``LM_NN`` and
        ``LLAMA_NN`` environment variables are set to ``'0'``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        else:
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        # NOTE: no AWQ-specific weight repacking is performed here, and
        # ``use_awq_pad`` is not consulted.
        return loaded_params

    def _apply_llama_nn_layout(self, params_dict: dict[str, nn.Parameter],
                               loaded_params: set[str]) -> None:
        """Re-lay-out the loaded projection weights for the LLAMA_NN path.

        Sets the ``LM_NN`` environment variable once, before any weight is
        touched: ``'1'`` if a loaded ``lm_head.weight`` has a second
        dimension of at least 4096, else ``'0'``. (Previously the variable
        was rewritten on every iteration over an unordered set, so its final
        value depended on iteration order.)

        Each loaded parameter whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS`` is then passed through
        ``ops.trans_w16_gemm`` and reshaped to ``(shape[1], -1)``.
        """
        has_wide_lm_head = any(
            "lm_head.weight" in layername
            and params_dict[layername].shape[1] >= 4096
            for layername in loaded_params)
        os.environ['LM_NN'] = '1' if has_wide_lm_head else '0'

        for layername in loaded_params:
            # Plain substring match; the keys are literal name fragments.
            if not any(key in layername
                       for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue
            weight = params_dict[layername]
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            # NOTE(review): trans_w16_gemm is an opaque custom op; presumably
            # it writes a transposed/re-tiled copy of weight.data into
            # _weight - confirm against the op's implementation.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)

            weight.data = weight.data.reshape(ori_shape[1], -1)
473

474

475
476
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    """Shared causal-LM wrapper: a ``BaiChuanModel`` plus an LM head.

    Subclasses select the position-embedding scheme through the
    ``position_embedding`` constructor argument.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        """Build the decoder stack, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config, the quantization
                config and the LoRA config.
            prefix: Dotted module path forwarded to ``BaiChuanModel``.
            position_embedding: Forwarded to ``BaiChuanModel``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        # Route LM-head loading through the custom loader defined below.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``BaiChuanModel`` and return its result."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn ``hidden_states`` into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded parameter names).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        """Load the LM-head weight, normalizing it first for Baichuan2."""
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        is_baichuan2 = self.config.vocab_size == 125696
        if is_baichuan2:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)

        # With tensor parallelism the row-parallel loader is used; otherwise
        # the default loader.
        if self.tp_size > 1:
            row_parallel_weight_loader(param, loaded_weight)
        else:
            default_weight_loader(param, loaded_weight)
557
558


559
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Pick the position-embedding scheme from the hidden size.

        A hidden size of 4096 is treated as Baichuan2 7B and uses rotary
        embeddings; any other size (Baichuan 13B, Baichuan2 13B) uses ALiBi.
        """
        config = vllm_config.model_config.hf_config
        position_embedding = "ROPE" if config.hidden_size == 4096 else "ALIBI"
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding=position_embedding)
574
575


576
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model with rotary position embeddings."""
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding="ROPE")