baichuan.py 23.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

codethazine's avatar
codethazine committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
21
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
22
import math
23
from typing import Iterable, Optional, Set, Tuple, Union
codethazine's avatar
codethazine committed
24
25

import torch
26
from torch import nn
27
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
28

zhuwenwen's avatar
zhuwenwen committed
29
import os
zhuwenwen's avatar
zhuwenwen committed
30
import re
31
import vllm.envs as envs
zhuwenwen's avatar
zhuwenwen committed
32

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
37
                              get_tensor_model_parallel_world_size)
38
39
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
40
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
41
42
                                               QKVParallelLinear,
                                               RowParallelLinear)
43
from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
from vllm.model_executor.layers.quantization import QuantizationConfig
45
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
46
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
52

53
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
54
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
55
                    make_empty_intermediate_tensors_factory, make_layers)
codethazine's avatar
codethazine committed
56

zhuwenwen's avatar
zhuwenwen committed
57
from vllm import _custom_ops as ops
58
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
codethazine's avatar
codethazine committed
59

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
94
        quant_config: Optional[QuantizationConfig] = None,
95
96
97
98
99
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
100
            quant_config=quant_config)
101
102
103
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
104
                                           quant_config=quant_config)
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
127
        cache_config: Optional[CacheConfig] = None,
128
        quant_config: Optional[QuantizationConfig] = None,
129
        prefix: str = "",
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
151
            quant_config=quant_config,
152
153
154
155
156
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
157
            quant_config=quant_config,
158
159
160
161
162
163
164
165
166
167
        )
        # Create the alibi slopes and slice them.
        if self.postion_embedding == "ALIBI":
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
168
169
170
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scaling,
171
                                  alibi_slopes=alibi_slopes,
172
173
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
174
175
176
177
178
179
180
181
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.scaling = self.head_dim**-0.5
182
183
184
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
185
                                  cache_config=cache_config,
186
187
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
188
            
zhuwenwen's avatar
zhuwenwen committed
189
190
191
192
            self.quant_method = None
            if quant_config is not None:
                self.quant_method=quant_config.get_name()
                self.quant_config=quant_config
193
194
195
196
197
198
199

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
zhuwenwen's avatar
zhuwenwen committed
200
201
        # if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
        #     qkv = qkv[...,:-32]
202
203
204
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
205
        attn_output = self.attn(q, k, v)
206
207
208
209
210
211
212
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):

    def __init__(self,
213
                 config: PretrainedConfig,
214
                 position_embedding: str,
215
                 cache_config: Optional[CacheConfig] = None,
216
217
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
218
219
220
221
222
223
224
225
226
227
228
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
229
            cache_config=cache_config,
230
            quant_config=quant_config,
231
            prefix=f"{prefix}.self_attn",
232
233
234
235
236
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
237
            quant_config=quant_config,
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


269
@support_torch_compile
270
271
class BaiChuanModel(nn.Module):

272
273
274
275
276
277
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
278
        super().__init__()
279
280
281
282
283

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

284
285
286
287
288
289
290
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
291
292
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
293
294
295
296
297
            lambda prefix: BaiChuanDecoderLayer(config,
                                                position_embedding,
                                                cache_config,
                                                quant_config,
                                                prefix=prefix),
298
299
            prefix=f"{prefix}.layers",
        )
300
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301
302
303
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
zhuwenwen's avatar
zhuwenwen committed
304
305
306
307
308
309
310
311
312
313
        
        self.quant_method = None
        if quant_config is not None:
            self.quant_method=quant_config.get_name()
            self.quant_config=quant_config

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
314

315
316
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)
317
318
319
320
321

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
322
        intermediate_tensors: Optional[IntermediateTensors],
323
        inputs_embeds: Optional[torch.Tensor] = None,
324
325
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
326
327
328
329
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
330
331
332
333
334
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
335
        for layer in self.layers[self.start_layer:self.end_layer]:
336
337
338
339
340
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
341
342
343
344
345
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
346
347
348
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

349
350
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
351
352
353
354
355
356
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
357
        loaded_params: Set[str] = set()
358
        for name, loaded_weight in weights:
codethazine's avatar
codethazine committed
359
360
            if "rotary_emb.inv_freq" in name:
                continue
361

362
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
codethazine's avatar
codethazine committed
363
364
                if weight_name not in name:
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
365
366
367
368
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
369
370
                if is_pp_missing_parameter(name, self):
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
371
                param = params_dict[name]
372
373
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
codethazine's avatar
codethazine committed
374
                break
375
            else:
CHU Tianxiang's avatar
CHU Tianxiang committed
376
377
378
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
379
380
                if is_pp_missing_parameter(name, self):
                    continue
381
382
383
384
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
385
            loaded_params.add(name)
zhuwenwen's avatar
zhuwenwen committed
386
            
387
        if self.use_llama_nn and self.quant_method is None :
zhuwenwen's avatar
zhuwenwen committed
388
389
390
391
            lay_key_words = [
                "self_attn.W_pack.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
392
393
                "mlp.down_proj.weight",
                "lm_head.weight"
zhuwenwen's avatar
zhuwenwen committed
394
395
396
            ]
            combined_words = "|".join(lay_key_words)
            
zhuwenwen's avatar
zhuwenwen committed
397
398
            # lay_qkv_words = ["self_attn.W_pack.weight"]   
            # qkv_words = "|".join(lay_qkv_words)  
zhuwenwen's avatar
zhuwenwen committed
399
            
zhuwenwen's avatar
zhuwenwen committed
400
401
            for layername in loaded_params:
                weight = params_dict[layername]
zhuwenwen's avatar
zhuwenwen committed
402
                matches = re.findall(combined_words, layername)
403
                if matches:      
zhuwenwen's avatar
zhuwenwen committed
404
405
                    # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                    #     weight.data = pad_weight(weight.data, 32)  
406
                        
zhuwenwen's avatar
zhuwenwen committed
407
408
409
                    # if self.use_fa_pad and (re.findall(qkv_words, layername)):
                    #     if not gemm_bank_conf(weight.data.shape[0]):
                    #         weight.data = pad_weight(weight.data, 32)
410
                                    
zhuwenwen's avatar
zhuwenwen committed
411
412
413
414
415
416
417
                    _weight = torch.zeros_like(weight.data)
                    ori_shape =_weight.shape
                    
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
                    weight.data.copy_(_weight)
                    
                    weight.data=weight.data.reshape(ori_shape[1], -1)
418
        else:
zhuwenwen's avatar
zhuwenwen committed
419
            os.environ['LM_NN'] = '0'
420
421
422
            os.environ['LLAMA_NN'] = '0'
            
        if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
423
424
425
426
427
428
429
430
            lay_key_words = [
                "self_attn.W_pack.qweight",
                "self_attn.o_proj.qweight",
                "mlp.gate_up_proj.qweight",
                "mlp.down_proj.qweight"
            ]
            combined_words = "|".join(lay_key_words)
            
zhuwenwen's avatar
zhuwenwen committed
431
432
            for layername in loaded_params:
                weight = params_dict[layername]
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
                
                matches = re.findall(combined_words, layername)
                if matches:
                    qweight =params_dict[layername]
                    qzeros=params_dict[layername.replace("qweight", "qzeros")]
                    scales=params_dict[layername.replace("qweight", "scales")]
                    zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
                    
                    group_size= self.quant_config.group_size 
                   
                    dim_n = scales.data.shape[1]
                    dim_k = qweight.data.shape[0]
                    pad_group=2              
                    
                    _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size)) 
                    
                    sz = ops.sz_permute(_sz).reshape(-1,dim_n)       
                    
                    zeros_and_scalse.data.copy_(sz)
                    qweight.data.copy_(_qw)
                    
                    #reshape
                    zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1)    #[k/greop_size,n]------>[n,k/group_size]
                    qweight.data=qweight.data.reshape(dim_n,-1)                      #[k,n/8]---->[n,k/8]  
                
                    if dim_k % 4096==0 and self.use_awq_pad:
                        zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
                        zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
                        qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
zhuwenwen's avatar
zhuwenwen committed
462
                        qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()  
463
        return loaded_params
464

465

466
467
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
468
469
470
471
472
473
474
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
475

476
477
    def __init__(
        self,
478
        *,
479
480
481
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
482
    ):
483
        super().__init__()
484
485
486
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
487
        self.config = config
488
489
        self.lora_config = lora_config

490
        self.quant_config = quant_config
491
492
493
        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
494
495
496
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
497
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
498
499
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
500
        self.logits_processor = LogitsProcessor(config.vocab_size)
Joe Runde's avatar
Joe Runde committed
501
        self.sampler = get_sampler()
502
503
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
zhuwenwen's avatar
zhuwenwen committed
504
        
505

506
507
508
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

509
510
511
512
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
513
        intermediate_tensors: Optional[IntermediateTensors] = None,
514
        inputs_embeds: Optional[torch.Tensor] = None,
515
    ) -> Union[torch.Tensor, IntermediateTensors]:
516
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
517
                                   inputs_embeds)
518
519
        return hidden_states

520
521
522
523
524
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
525
        logits = self.logits_processor(self.lm_head, hidden_states,
526
527
528
                                       sampling_metadata)
        return logits

529
530
    def sample(
        self,
531
        logits: torch.Tensor,
532
533
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
534
        next_tokens = self.sampler(logits, sampling_metadata)
535
        return next_tokens
codethazine's avatar
codethazine committed
536

537
538
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        is_baichuan2 = self.config.vocab_size == 125696
        if is_baichuan2:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)

        default_weight_loader(param, loaded_weight)
555
556


557
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
558
559
560
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """
561

562
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
563
        config = vllm_config.model_config.hf_config
564
        if config.hidden_size == 4096:  # baichuan2 7b
565
566
567
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             position_embedding="ROPE")
568
        else:  # baichuan 13b, baichuan2 13b
569
570
571
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             position_embedding="ALIBI")
572
573


574
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
575
576
577
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.
    """
578

579
580
581
582
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding="ROPE")