Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4fc7337
Unverified
Commit
f4fc7337
authored
Aug 19, 2024
by
Zijian Hu
Committed by
GitHub
Aug 19, 2024
Browse files
[Bugfix] support `tie_word_embeddings` for all models (#5724)
parent
0df7ec0b
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
66 additions
and
13 deletions
+66
-13
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+2
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+2
-0
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+2
-0
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-0
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+7
-2
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+3
-0
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+3
-0
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+3
-0
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+2
-0
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+2
-0
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+2
-0
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+6
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+8
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+3
-1
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+2
-0
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+6
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+2
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+2
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+4
-0
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+2
-0
No files found.
vllm/model_executor/models/arctic.py
View file @
f4fc7337
...
...
@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config
.
hidden_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
num_experts
=
config
.
num_local_experts
self
.
num_experts_per_tok
=
config
.
num_experts_per_tok
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
vllm/model_executor/models/baichuan.py
View file @
f4fc7337
...
...
@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/bart.py
View file @
f4fc7337
...
...
@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
super
().
__init__
()
# currently all existing BART models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
model
=
BartModel
(
config
,
cache_config
,
...
...
vllm/model_executor/models/blip2.py
View file @
f4fc7337
...
...
@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super
().
__init__
()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
vllm/model_executor/models/bloom.py
View file @
f4fc7337
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
word_embeddings
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/chatglm.py
View file @
f4fc7337
...
...
@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
8192
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
transformer
.
output_layer
.
weight
=
(
self
.
transformer
.
embedding
.
weight
)
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/commandr.py
View file @
f4fc7337
...
...
@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
)
->
None
:
super
().
__init__
()
self
.
config
=
config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/dbrx.py
View file @
f4fc7337
...
...
@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
):
super
().
__init__
()
self
.
config
=
config
if
config
.
tie_word_embeddings
:
raise
ValueError
(
"tie_word_embeddings is not supported for Dbrx models."
)
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
...
...
vllm/model_executor/models/deepseek.py
View file @
f4fc7337
...
...
@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gemma.py
View file @
f4fc7337
...
...
@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super
().
__init__
()
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/gemma2.py
View file @
f4fc7337
...
...
@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del
lora_config
# Unused.
super
().
__init__
()
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
quant_config
=
quant_config
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
...
...
vllm/model_executor/models/gpt2.py
View file @
f4fc7337
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config
,
quant_config
,
prefix
=
"transformer"
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
f4fc7337
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
,
lora_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
transformer
.
vocab_size
,
self
.
transformer
.
embed_dim
,
org_num_embeddings
=
self
.
config
.
vocab_size
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/gpt_neox.py
View file @
f4fc7337
...
...
@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
config
:
GPTNeoXConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
...
...
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config
.
hidden_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
embed_out
.
weight
=
self
.
gpt_neox
.
embed_in
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/internlm2.py
View file @
f4fc7337
...
...
@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/jais.py
View file @
f4fc7337
...
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
if
hasattr
(
config
,
"width_scale"
):
self
.
output_logits_scale
=
config
.
width_scale
else
:
...
...
vllm/model_executor/models/llava.py
View file @
f4fc7337
vllm/model_executor/models/llava_next.py
View file @
f4fc7337
vllm/model_executor/models/minicpmv.py
View file @
f4fc7337
...
...
@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
# All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
# and config class
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
vllm/model_executor/models/mixtral.py
View file @
f4fc7337
...
...
@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment