Unverified commit a3339d8c, authored Feb 22, 2025 by fzyzcjy, committed by GitHub on Feb 21, 2025
Bug: Fix weight loader error when LM head weights are tied (#3766)
parent 14d90617
Showing 8 changed files with 16 additions and 0 deletions (+16, -0).
python/sglang/srt/models/llama.py  (+2, -0)
python/sglang/srt/models/minicpm.py  (+2, -0)
python/sglang/srt/models/minicpm3.py  (+2, -0)
python/sglang/srt/models/olmo.py  (+2, -0)
python/sglang/srt/models/phi3_small.py  (+2, -0)
python/sglang/srt/models/qwen2.py  (+2, -0)
python/sglang/srt/models/qwen2_vl.py  (+2, -0)
python/sglang/srt/models/torch_native_llama.py  (+2, -0)
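All eight files receive the same two-line guard in their load_weights method. When config.tie_word_embeddings is set, the output projection reuses the input embedding tensor and the model keeps no separate lm_head.weight parameter, yet checkpoints may still ship a tensor under that name; looking it up in params_dict then fails. Below is a minimal, self-contained sketch of the pattern (the helper names and the simplified loop are illustrative, not the exact sglang code):

import torch

def default_weight_loader(param, loaded_weight):
    # Fallback loader used in this sketch: copy the checkpoint tensor in place.
    param.data.copy_(loaded_weight)

def load_weights(model, config, weights):
    # Simplified per-model weight-loading loop with the guard this commit adds.
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # Tied embeddings: no separate lm_head.weight parameter exists, so the
        # checkpoint tensor is skipped instead of failing at params_dict[name].
        if config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)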
python/sglang/srt/models/llama.py
@@ -458,6 +458,8 @@ class LlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             # Handle FP8 kv-scale remapping
             if "scale" in name:
                 name = maybe_remap_kv_scale_name(name, params_dict)
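The guard is needed because of how weight tying behaves in PyTorch: once lm_head shares the embedding tensor, named_parameters() deduplicates the shared parameter and reports it only under the embedding's name, so "lm_head.weight" is absent from params_dict. A generic illustration of that effect (a toy module, not the sglang implementation):

import torch.nn as nn

class TiedTinyLM(nn.Module):
    # Toy model used only to show the effect of tying on named_parameters().
    def __init__(self, vocab_size=100, hidden=16, tie_word_embeddings=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        if tie_word_embeddings:
            # Share one tensor between input embedding and output projection.
            self.lm_head.weight = self.embed_tokens.weight

model = TiedTinyLM()
# The shared tensor is reported once, under "embed_tokens.weight" only.
print("lm_head.weight" in dict(model.named_parameters()))  # False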
python/sglang/srt/models/minicpm.py
@@ -339,6 +339,8 @@ class MiniCPMForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
python/sglang/srt/models/minicpm3.py
@@ -603,6 +603,8 @@ class MiniCPM3ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
python/sglang/srt/models/olmo.py
@@ -325,6 +325,8 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
python/sglang/srt/models/phi3_small.py
@@ -433,6 +433,8 @@ class Phi3SmallForCausalLM(nn.Module):
                 continue
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
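The phi3_small hunk makes the failure mode visible: params_dict[name] is the lookup that breaks when the checkpoint contains lm_head.weight but the tied model does not. A toy reproduction of that situation (made-up names and shapes):

import torch

# What named_parameters() yields on a tied model: the shared tensor appears
# once, under the embedding name only.
params_dict = {"model.embed_tokens.weight": torch.zeros(4, 2)}

name = "lm_head.weight"
tie_word_embeddings = True

if tie_word_embeddings and "lm_head.weight" in name:
    pass  # with the new guard the tensor is skipped; it is already loaded
else:
    param = params_dict[name]  # pre-fix path: raises KeyError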
python/sglang/srt/models/qwen2.py
@@ -377,6 +377,8 @@ class Qwen2ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
python/sglang/srt/models/qwen2_vl.py
@@ -586,6 +586,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
python/sglang/srt/models/torch_native_llama.py
@@ -486,6 +486,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name: