norm / vllm · Commits

Commit d4c7755c (unverified), authored Aug 02, 2023 by Qing, committed by GitHub Aug 01, 2023

fix biachuan-7b tp (#598)

Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>

Parent: aa39e42c
Showing 1 changed file with 37 additions and 10 deletions (+37, -10).

vllm/model_executor/models/baichuan.py
@@ -251,8 +251,8 @@ class BaiChuanForCausalLM(nn.Module):
         return next_tokens
 
-    _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "W_pack.weight",
-        "gate_proj.weight", "up_proj.weight"
-    ]
+    _column_parallel_weights = [
+        "embed_tokens.weight",
+        "lm_head.weight",
+    ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
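The weights dropped from `_column_parallel_weights` here cannot be sharded by a plain row-range split: `W_pack` stacks the Q, K, and V projections, so under tensor parallelism each rank must take its own attention heads from each of the three stacked blocks, which is what the new `load_weights` logic in the hunks below does. A minimal standalone sketch of the difference, with hypothetical toy shapes (not vLLM code):

    import torch

    hidden_size, num_heads, tp = 8, 4, 2
    head_size = hidden_size // num_heads
    # Packed QKV weight: Q, K, V stacked along dim 0, as in W_pack.
    w_pack = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32)
    w_pack = w_pack.view(3 * hidden_size, hidden_size)

    # Naive column-parallel split: rank 0 gets the top half, i.e. all of Q
    # plus part of K -- not its own heads of Q, K, and V.
    naive_rank0 = w_pack[: w_pack.shape[0] // tp]

    # Head-aware split: pick this rank's heads from each of Q, K, and V.
    rank, heads_per_rank = 0, num_heads // tp
    packed = w_pack.view(3, num_heads, head_size, hidden_size)
    shard = packed[:, rank * heads_per_rank:(rank + 1) * heads_per_rank]
    head_aware_rank0 = shard.reshape(-1, hidden_size)

    # Same shape, different rows: the naive split is simply wrong here.
    assert not torch.equal(naive_rank0, head_aware_rank0)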
@@ -260,7 +260,8 @@ class BaiChuanForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        tp_world_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
@@ -268,15 +269,37 @@ class BaiChuanForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            if "embed_tokens" in name or "lm_head" in name:
+                # Consider padding in the vocab size.
+                param = state_dict[name]
+                padded_vocab_size = param.shape[0] * tp_world_size
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
+            if "W_pack" in name:
+                total_num_heads = self.config.num_attention_heads
+                hidden_size = self.config.hidden_size
+                head_size = hidden_size // total_num_heads
+                num_heads = total_num_heads // tp_world_size
+                head_start = tp_rank * num_heads
+                head_end = (tp_rank + 1) * num_heads
+                loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                   head_size, hidden_size)
+                loaded_weight = loaded_weight[:, head_start:head_end, :, :]
+                loaded_weight = loaded_weight.reshape(-1, hidden_size)
+
             is_gate_up_weight = False
             for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
                 shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[
+                    shard_size * tp_rank:shard_size * (tp_rank + 1)]
                 param_slice = param.data[shard_size * stride_id:shard_size *
                                          (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
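The vocab-padding branch recovers the padded vocabulary size from the sharded parameter itself: each rank holds `param.shape[0]` rows, so `param.shape[0] * tp_world_size` is the padded total, and the checkpoint's `embed_tokens`/`lm_head` weights are extended with `num_extra_rows` of uninitialized padding before sharding. Illustrative arithmetic only, assuming rows are allocated by ceiling division across ranks (the actual padding rule lives in vLLM's parallel embedding layer), with an example 64000-entry vocabulary and 8-way tensor parallelism:

    vocab_size, tp_world_size = 64000, 8
    rows_per_rank = -(-vocab_size // tp_world_size)     # ceil division -> 8000
    padded_vocab_size = rows_per_rank * tp_world_size   # 64000
    num_extra_rows = padded_vocab_size - vocab_size     # 0: divides evenly

    vocab_size = 64001                                  # a size that does not divide
    rows_per_rank = -(-vocab_size // tp_world_size)     # 8001
    padded_vocab_size = rows_per_rank * tp_world_size   # 64008
    num_extra_rows = padded_vocab_size - vocab_size     # 7 rows of padding
    print(num_extra_rows)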
@@ -287,7 +310,11 @@ class BaiChuanForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
-            load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+            load_tensor_parallel_weights(
+                param,
+                loaded_weight,
+                name,
+                self._column_parallel_weights,
+                self._row_parallel_weights,
+                tp_rank,
+            )
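For the gate/up branch in the third hunk: each rank's fused `gate_up_proj` parameter stores that rank's shard of `gate_proj` in its first half and its shard of `up_proj` in its second, which is what the `tp_rank`/`stride_id` slicing assembles. A self-contained sketch with hypothetical shapes (not vLLM code):

    import torch

    intermediate, hidden, tp, tp_rank = 8, 4, 2, 0
    gate = torch.randn(intermediate, hidden)  # full gate_proj.weight from the checkpoint
    up = torch.randn(intermediate, hidden)    # full up_proj.weight from the checkpoint

    # Per-rank fused parameter: this rank's gate shard stacked on its up shard.
    param = torch.empty(2 * intermediate // tp, hidden)
    shard_size = param.shape[0] // 2          # rows per projection on this rank
    for stride_id, full_weight in enumerate([gate, up]):
        # tp_rank selects which slice of the full weight this rank owns;
        # stride_id selects where it lands inside the fused parameter.
        shard = full_weight[shard_size * tp_rank:shard_size * (tp_rank + 1)]
        param[shard_size * stride_id:shard_size * (stride_id + 1)] = shard

    assert torch.equal(param[:shard_size], gate[:shard_size])
    assert torch.equal(param[shard_size:], up[:shard_size])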