Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07c11cf4
Unverified
Commit
07c11cf4
authored
Oct 10, 2024
by
Isotr0py
Committed by
GitHub
Oct 10, 2024
Browse files
[Bugfix] Fix lm_head weights tying with lora for llama (#9227)
parent
f3a507f1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
2 deletions
+12
-2
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+10
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-1
No files found.
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
07c11cf4
...
@@ -443,7 +443,7 @@ class ParallelLMHead(VocabParallelEmbedding):
...
@@ -443,7 +443,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super
().
__init__
(
num_embeddings
,
embedding_dim
,
params_dtype
,
super
().
__init__
(
num_embeddings
,
embedding_dim
,
params_dtype
,
org_num_embeddings
,
padding_size
,
quant_config
,
org_num_embeddings
,
padding_size
,
quant_config
,
prefix
)
prefix
)
self
.
quant_config
=
quant_config
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
torch
.
empty
(
self
.
num_embeddings_per_partition
,
...
@@ -455,6 +455,15 @@ class ParallelLMHead(VocabParallelEmbedding):
...
@@ -455,6 +455,15 @@ class ParallelLMHead(VocabParallelEmbedding):
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
def
tie_weights
(
self
,
embed_tokens
:
VocabParallelEmbedding
):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if
self
.
quant_config
and
self
.
quant_config
.
get_name
()
==
"gguf"
:
return
embed_tokens
else
:
self
.
weight
=
embed_tokens
.
weight
return
self
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
del
input_
del
input_
raise
RuntimeError
(
"LMHead's weights should be used in the sampler."
)
raise
RuntimeError
(
"LMHead's weights should be used in the sampler."
)
vllm/model_executor/models/llama.py
View file @
07c11cf4
...
@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
if
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment