Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ce64e2d
Unverified
Commit
4ce64e2d
authored
May 23, 2025
by
Mengqing Cao
Committed by
GitHub
May 23, 2025
Browse files
[Bugfix][Model] Fix baichuan model loader for tp (#18597)
Signed-off-by:
Mengqing Cao
<
cmq0113@163.com
>
parent
fbb13a2c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+7
-4
No files found.
vllm/model_executor/models/baichuan.py
View file @
4ce64e2d
...
...
@@ -42,7 +42,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -384,7 +385,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
...
...
@@ -438,8 +439,10 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
is_baichuan2
=
self
.
config
.
vocab_size
==
125696
if
is_baichuan2
:
loaded_weight
=
torch
.
nn
.
functional
.
normalize
(
loaded_weight
)
default_weight_loader
(
param
,
loaded_weight
)
if
self
.
tp_size
>
1
:
row_parallel_weight_loader
(
param
,
loaded_weight
)
else
:
default_weight_loader
(
param
,
loaded_weight
)
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment