Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3c93187c
Unverified
Commit
3c93187c
authored
Sep 24, 2024
by
TianyiQ
Committed by
GitHub
Sep 24, 2024
Browse files
Add support for tie_word_embeddings when loading weights + support for SmolLM (#1508)
parent
fb2d0680
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
0 deletions
+10
-0
README.md
README.md
+1
-0
python/sglang/srt/models/llama.py
python/sglang/srt/models/llama.py
+8
-0
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+1
-0
No files found.
README.md
View file @
3c93187c
...
@@ -263,6 +263,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
...
@@ -263,6 +263,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
-
BaiChuan2
-
BaiChuan2
-
MiniCPM / MiniCPM 3
-
MiniCPM / MiniCPM 3
-
XVERSE / XVERSE MoE
-
XVERSE / XVERSE MoE
-
SmolLM
**Embedding Models**
**Embedding Models**
...
...
python/sglang/srt/models/llama.py
View file @
3c93187c
...
@@ -403,6 +403,14 @@ class LlamaForCausalLM(nn.Module):
...
@@ -403,6 +403,14 @@ class LlamaForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
if
(
hasattr
(
self
.
config
,
"tie_word_embeddings"
)
and
self
.
config
.
tie_word_embeddings
):
# Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
param
=
self
.
lm_head
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
self
.
model
.
embed_tokens
.
weight
)
apply_torchao_config_
(
self
,
params_dict
,
set
([
"proj.weight"
]))
apply_torchao_config_
(
self
,
params_dict
,
set
([
"proj.weight"
]))
...
...
test/srt/models/test_generation_models.py
View file @
3c93187c
...
@@ -51,6 +51,7 @@ CI_MODELS = [
...
@@ -51,6 +51,7 @@ CI_MODELS = [
# All other models
# All other models
ALL_OTHER_MODELS
=
[
ALL_OTHER_MODELS
=
[
ModelCase
(
"Qwen/Qwen2-1.5B"
),
ModelCase
(
"Qwen/Qwen2-1.5B"
),
ModelCase
(
"HuggingFaceTB/SmolLM-135M-Instruct"
),
]
]
TORCH_DTYPES
=
[
torch
.
float16
]
TORCH_DTYPES
=
[
torch
.
float16
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment