baichuan.py 22.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

codethazine's avatar
codethazine committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
23

24
25
26
import os
import re

27
import math
28
from collections.abc import Iterable
29
from itertools import islice
codethazine's avatar
codethazine committed
30
31

import torch
32
from torch import nn
33
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
34

35
from vllm import envs
36
from vllm.attention.layer import Attention
37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
40
41
42
43
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
44
45
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
46
47
48
49
50
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.quantization import QuantizationConfig
53
from vllm.model_executor.layers.rotary_embedding import get_rope
54
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
from vllm.model_executor.model_loader.weight_utils import (
59
60
61
    default_weight_loader,
    row_parallel_weight_loader,
)
62
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
63

64
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
65
66
67
68
69
70
71
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
codethazine's avatar
codethazine committed
72

zhuwenwen's avatar
zhuwenwen committed
73
from vllm import _custom_ops as ops
74
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
codethazine's avatar
codethazine committed
75

76
77
78


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the ALiBi slope for each of ``total_num_heads`` attention heads.

    Let ``n = 2 ** floor(log2(total_num_heads))``. The first ``n`` heads get
    the geometric sequence ``b**1, b**2, ..., b**n`` with ``b = 2 ** (-8 / n)``.
    Any remaining heads (at most ``n`` of them) get odd powers
    ``c**1, c**3, c**5, ...`` of ``c = 2 ** (-8 / (2 * n))``.

    Returns:
        A 1-D float32 tensor holding the concatenated slopes.
    """

    def _slope_base(head_count: int) -> torch.Tensor:
        # 2 ** (-(2 ** -(log2(head_count) - 3))) == 2 ** (-8 / head_count)
        exponent = -(2 ** -(math.log2(head_count) - 3))
        return torch.tensor(2**exponent, dtype=torch.float32)

    pow2_heads = 2 ** math.floor(math.log2(total_num_heads))
    slopes = torch.pow(
        _slope_base(pow2_heads),
        torch.arange(1, pow2_heads + 1, dtype=torch.int32),
    )
    if total_num_heads == pow2_heads:
        return slopes

    # Non-power-of-two head counts: interleave-style extra slopes taken from
    # the odd powers of the base for twice as many heads.
    leftover = min(pow2_heads, total_num_heads - pow2_heads)
    odd_powers = torch.arange(
        start=1, end=2 * leftover + 1, step=2, dtype=torch.int32
    )
    extra = torch.pow(_slope_base(2 * pow2_heads), odd_powers)
    return torch.cat([slopes, extra], dim=0)


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block used by every BaiChuan decoder layer.

    ``gate_up_proj`` computes the gate and up projections in a single fused
    linear layer, ``SiluAndMul`` combines them, and ``down_proj`` maps the
    result back to ``hidden_size``. Only ``hidden_act == "silu"`` is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Two fused output slices of ``intermediate_size`` each: gate, then up.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    BaiChuan packs Q, K and V into one projection (``W_pack``). Position
    information comes either from ALiBi slopes
    (``position_embedding == "ALIBI"``) or from rotary embeddings (any other
    value). Heads are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_parameters: dict | None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.max_position_embeddings = max_position_embeddings
        # Same softmax scale for both position-embedding variants; stored on
        # self in both cases (previously only the RoPE path set it).
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.W_pack",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        alibi_slopes: list[float] | None = None
        if self.position_embedding == "ALIBI":
            # Each TP rank keeps only the slopes of the heads it owns.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = head_start + self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)[
                head_start:head_end
            ].tolist()
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                max_position=self.max_position_embeddings,
                rope_parameters=rope_parameters,
            )

        # cache_config is forwarded for both variants. The ALiBi path used to
        # drop it, so KV-cache settings were ignored for ALiBi checkpoints.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Fork-specific bookkeeping: name of the quantization method (None when
        # unquantized). quant_config is now always set (None when unquantized)
        # instead of being absent.
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )
        self.quant_config = quant_config

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to packed QKV, apply position encoding, attend, project out.

        NOTE(review): hidden_states is presumably (num_tokens, hidden_size) —
        confirm against the caller.
        """
        qkv, _ = self.W_pack(hidden_states)
        # NOTE: an FA_PAD path that trimmed the last 32 columns of ``qkv``
        # (unquantized models only) is disabled in this fork.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One pre-norm transformer block: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is passed into and returned from the two-argument
    ``RMSNorm`` calls rather than being added explicitly in this class.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        position_embedding: str,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        self.self_attn = BaiChuanAttention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_parameters=getattr(config, "rope_parameters", None),
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(hidden, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # On the first layer there is no running residual yet: start it from
        # the raw input and normalize without a residual argument.
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


290
@support_torch_compile
class BaiChuanModel(nn.Module):
    """BaiChuan decoder stack: token embedding, decoder layers, final RMSNorm.

    The LM head is not part of this module; ``BaiChuanBaseForCausalLM`` owns it.

    Fork-specific behaviour:
    - Four environment flags are read once at construction: ``LLAMA_NN``,
      ``GEMM_PAD``, ``FA_PAD`` and ``AWQ_PAD``.
    - Only ``LLAMA_NN`` is acted on in this class (see ``load_weights``).
    - The other three are stored as attributes, but no code path in this
      class reads them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(
                config, position_embedding, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Quantization method name; load_weights skips the LLAMA_NN re-layout
        # whenever this is not None. quant_config is always set (None when
        # unquantized) rather than only on the quantized path.
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )
        self.quant_config = quant_config

        # Environment toggles, read once at construction time.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.use_gemm_pad = os.environ.get("GEMM_PAD") == "1"
        self.use_fa_pad = os.environ.get("FA_PAD") == "1"
        self.use_awq_pad = os.environ.get("AWQ_PAD") == "1"

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Loading rules:
        - Separate ``gate_proj`` and ``up_proj`` tensors are loaded as shards
          0 and 1 of the fused ``gate_up_proj`` parameter.
        - Every other tensor is loaded by name.
        - ``rotary_emb.inv_freq`` tensors are skipped, as are extra GPTQ
          biases and parameters this pipeline-parallel rank does not own.

        Post-processing:
        - If ``LLAMA_NN`` was set at construction and the model is
          unquantized, selected dense weights are re-laid-out in place (see
          ``_apply_llama_nn_layout``).
        - Otherwise the process-wide ``LM_NN`` and ``LLAMA_NN`` environment
          variables are both set to ``"0"``.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE(review): these `continue`s advance to the next mapping
                # entry, not the next checkpoint tensor. A skipped weight then
                # falls through to the `else` branch, which repeats the skip
                # checks.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        else:
            os.environ["LM_NN"] = "0"
            os.environ["LLAMA_NN"] = "0"
        # NOTE: an AWQ repacking path (ops.convert_s4 / ops.sz_permute, gated
        # on AWQ_PAD) and GEMM_PAD/FA_PAD weight padding used to sit here
        # commented out. They were removed as dead code; see VCS history.
        return loaded_params

    def _apply_llama_nn_layout(
        self,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
    ) -> None:
        """Re-lay-out selected loaded weights and set the ``LM_NN`` env flag.

        Which weights: those whose name contains the fused QKV (``W_pack``),
        ``o_proj``, ``gate_up_proj``, ``down_proj`` or ``lm_head`` weight key.

        The ``LM_NN`` flag:
        - It is set to ``"1"`` iff a loaded ``lm_head.weight`` with
          ``shape[1] >= 4096`` was seen, and ``"0"`` otherwise.
        - It is now decided once after the loop. Previously it was rewritten
          on every iteration over the set, so the outcome depended on set
          ordering.
        - ``BaiChuanModel.__init__`` creates no ``lm_head`` parameter, so the
          ``lm_head`` cases are not expected to fire for this class.
        """
        nn_layout_keys = (
            "self_attn.W_pack.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "lm_head.weight",
        )
        lm_head_nn = False
        for layer_name in loaded_params:
            param = params_dict[layer_name]
            # Check the shape before the re-layout below changes it.
            if "lm_head.weight" in layer_name and param.shape[1] >= 4096:
                lm_head_nn = True
            if any(key in layer_name for key in nn_layout_keys):
                self._relayout_w16(param)
        os.environ["LM_NN"] = "1" if lm_head_nn else "0"

    @staticmethod
    def _relayout_w16(param: nn.Parameter) -> None:
        """Rewrite ``param`` in place using the ``ops.trans_w16_gemm`` kernel.

        Steps:
        - The kernel fills a zero-initialised scratch buffer shaped like
          ``param.data``.
        - The scratch buffer is copied back into ``param.data``.
        - ``param.data`` is reshaped to ``(shape[1], -1)`` of the original
          shape.

        NOTE(review): the kernel name and the final reshape suggest a
        transpose into an "NN" GEMM layout for 16-bit 2-D weights. Confirm
        against the kernel.
        """
        scratch = torch.zeros_like(param.data)
        rows, cols = scratch.shape[0], scratch.shape[1]
        ops.trans_w16_gemm(scratch, param.data, rows, cols)
        param.data.copy_(scratch)
        param.data = param.data.reshape(cols, -1)
487

488

489
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Causal-LM wrapper shared by every BaiChuan variant.

    Combines a ``BaiChuanModel`` decoder stack with a ``ParallelLMHead`` and a
    ``LogitsProcessor``. The LM-head weight uses a custom loader that
    L2-normalizes it for Baichuan2 checkpoints.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = vllm_config.quant_config

        self.model = BaiChuanModel(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Route LM-head checkpoint tensors through lm_head_weight_loader.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            # Share the embedding table; this replaces the parameter that
            # carries the custom loader set just above.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Load the LM-head weight, L2-normalizing it first for Baichuan2.

        Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
        https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508

        The two generations are told apart by vocab size. This is suggested by
        https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        """
        baichuan2_vocab_size = 125696
        if self.config.vocab_size == baichuan2_vocab_size:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)
        load = row_parallel_weight_loader if self.tp_size > 1 else default_weight_loader
        load(param, loaded_weight)
572
573


574
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # hidden_size == 4096 identifies Baichuan2-7B, which uses RoPE.
        # The 13B variants (Baichuan-13B, Baichuan2-13B) use ALiBi.
        hidden_size = vllm_config.model_config.hf_config.hidden_size
        position_embedding = "ROPE" if hidden_size == 4096 else "ALIBI"
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
589
590


591
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B (rotary position embeddings).

    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The original Baichuan-7B always uses RoPE.
        super().__init__(
            prefix=prefix,
            vllm_config=vllm_config,
            position_embedding="ROPE",
        )