Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4fc7337
Unverified
Commit
f4fc7337
authored
Aug 19, 2024
by
Zijian Hu
Committed by
GitHub
Aug 19, 2024
Browse files
[Bugfix] support `tie_word_embeddings` for all models (#5724)
parent
0df7ec0b
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
24 additions
and
3 deletions
+24
-3
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+2
-0
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+6
-2
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+2
-0
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+2
-0
vllm/model_executor/models/phi3_small.py
vllm/model_executor/models/phi3_small.py
+2
-1
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+2
-0
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+2
-0
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+2
-0
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+2
-0
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+2
-0
No files found.
vllm/model_executor/models/mixtral_quant.py
View file @
f4fc7337
...
@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
...
@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/opt.py
View file @
f4fc7337
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
...
@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
OPTModel
(
config
,
cache_config
,
quant_config
)
self
.
model
=
OPTModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
self
.
model
.
decoder
.
embed_tokens
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
decoder
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
word_embed_proj_dim
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/orion.py
View file @
f4fc7337
...
@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
...
@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/phi.py
View file @
f4fc7337
...
@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
...
@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
# lm_head use bias, cannot share word embeddings
assert
not
config
.
tie_word_embeddings
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/phi3_small.py
View file @
f4fc7337
...
@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
...
@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
@@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
...
@@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
self
.
lm_head
.
weight
.
data
.
copy_
(
self
.
model
.
embed_tokens
.
weight
.
data
)
vllm/model_executor/models/phi3v.py
View file @
f4fc7337
...
@@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
...
@@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/qwen.py
View file @
f4fc7337
...
@@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module):
...
@@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
transformer
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
f4fc7337
...
@@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/stablelm.py
View file @
f4fc7337
...
@@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module):
...
@@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/xverse.py
View file @
f4fc7337
...
@@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
...
@@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment