sglang, commit 12cb115d (Unverified)

Fix llama2 weight loader (#1317)

Authored Sep 03, 2024 by Lianmin Zheng; committed via GitHub on Sep 03, 2024.
Parent: c500f96b

Showing 2 changed files with 8 additions and 29 deletions (+8, -29).
python/sglang/srt/models/exaone.py (+4, -25)
python/sglang/srt/models/llama.py (+4, -4)
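The fix itself is the change from `return` to `continue` in the skip branches of `load_weights`: previously, the first skippable checkpoint tensor (for example `rotary_emb.inv_freq`) made the method return immediately, so every remaining weight was left unloaded; now only that tensor is skipped. A minimal, self-contained sketch of the difference, using made-up tensor names rather than the real checkpoint layout:

```python
# Minimal sketch of the weight-loading loop; SKIP_SUBSTRINGS and the
# `weights` list below are illustrative, not the actual sglang code.
SKIP_SUBSTRINGS = ("rotary_emb.inv_freq", "projector")

def load_buggy(weights, params):
    for name, tensor in weights:
        if any(s in name for s in SKIP_SUBSTRINGS):
            return  # BUG: exits load_weights entirely; later tensors are never loaded
        params[name] = tensor

def load_fixed(weights, params):
    for name, tensor in weights:
        if any(s in name for s in SKIP_SUBSTRINGS):
            continue  # skip just this tensor and keep loading the rest
        params[name] = tensor

weights = [
    ("model.layers.0.self_attn.rotary_emb.inv_freq", "buf"),
    ("model.layers.0.self_attn.qkv_proj.weight", "w0"),
    ("model.layers.1.self_attn.qkv_proj.weight", "w1"),
]

buggy, fixed = {}, {}
load_buggy(weights, buggy)
load_fixed(weights, fixed)
print(len(buggy), len(fixed))  # 0 vs 2: the early return drops every later weight
```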
python/sglang/srt/models/exaone.py (view file @ 12cb115d)
@@ -323,27 +323,6 @@ class ExaoneForCausalLM(nn.Module):
         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
         return sample_output, logits_output
 
-    def get_module_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "c_fc_0", 0, 2),
-            ("gate_up_proj", "c_fc_1", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
-            if weight_name in name:
-                return (
-                    name.replace(weight_name, param_name)[: -len(".weight")],
-                    num_shard,
-                )
-        return name[: -len(".weight")], 1
-
-    def get_num_params(self):
-        params_dict = dict(self.named_parameters())
-        return len(params_dict)
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
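The removed `get_module_name` and `get_num_params` helpers appear to be unused here and are dropped as part of the cleanup. The name mapping they encoded (per-shard checkpoint name to fused module name) is the same idea the `stacked_params_mapping` inside `load_weights` still expresses. A rough sketch of that mapping, with a hypothetical checkpoint name and the shard id simplified out of the tuple:

```python
# Sketch mirroring the removed get_module_name logic (hypothetical names,
# not the real EXAONE checkpoint layout).
stacked_params_mapping = [
    # (param_name, shard_name, num_shards)
    ("qkv_proj", "q_proj", 3),
    ("qkv_proj", "k_proj", 3),
    ("qkv_proj", "v_proj", 3),
    ("gate_up_proj", "c_fc_0", 2),
    ("gate_up_proj", "c_fc_1", 2),
]

def module_name_for(name: str):
    for param_name, shard_name, num_shards in stacked_params_mapping:
        if shard_name in name:
            # e.g. "...attn.q_proj.weight" -> ("...attn.qkv_proj", 3)
            return name.replace(shard_name, param_name)[: -len(".weight")], num_shards
    return name[: -len(".weight")], 1

print(module_name_for("transformer.h.0.attn.q_proj.weight"))
# ('transformer.h.0.attn.qkv_proj', 3)
```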
@@ -357,13 +336,13 @@ class ExaoneForCausalLM(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                return
+                continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                return
+                continue
             if name.startswith("model.vision_tower") and name not in params_dict:
-                return
+                continue
             name = name.replace("attn.attention", "self_attn")
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -380,7 +359,7 @@ class ExaoneForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    return
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
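Context around the last changed line: `getattr(param, "weight_loader", default_weight_loader)` picks up a shard-aware loader when one is attached to a fused parameter (such as `qkv_proj`) and otherwise falls back to a plain copy. A simplified stand-in for that fallback, just to show the pattern (not the actual `default_weight_loader` sglang imports):

```python
import torch

def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor):
    # Simplified fallback: shapes must match, then copy in place.
    assert param.data.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)

param = torch.nn.Parameter(torch.zeros(4, 4))
loaded = torch.ones(4, 4)

# Mirrors the loop body: use a param-specific loader if one was attached,
# otherwise fall back to the plain copy above.
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded)
print(param.data.sum().item())  # 16.0
```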
python/sglang/srt/models/llama.py (view file @ 12cb115d)
@@ -334,13 +334,13 @@ class LlamaForCausalLM(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                return
+                continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                return
+                continue
             if name.startswith("model.vision_tower") and name not in params_dict:
-                return
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -356,7 +356,7 @@ class LlamaForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    return
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
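The llama.py hunks apply the same `return` to `continue` fix to `LlamaForCausalLM.load_weights`. A regression like the old early return shows up as parameters that never receive a checkpoint tensor; a hypothetical sanity check (not part of the repo) might look like:

```python
import torch

# Hypothetical helper: with the old early `return`, most parameters would
# never be written by load_weights and would be reported here.
def report_unloaded(model: torch.nn.Module, loaded_names: set) -> list:
    return [n for n, _ in model.named_parameters() if n not in loaded_names]

model = torch.nn.Linear(4, 4)        # stand-in for the real model
loaded_names = {"weight"}            # names the load loop actually wrote
print(report_unloaded(model, loaded_names))  # ['bias'] -> something was skipped
```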