Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4fc7337
Unverified
Commit
f4fc7337
authored
Aug 19, 2024
by
Zijian Hu
Committed by
GitHub
Aug 19, 2024
Browse files
[Bugfix] support `tie_word_embeddings` for all models (#5724)
parent
0df7ec0b
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
66 additions
and
13 deletions
+66
-13
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+2
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+2
-0
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+2
-0
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-0
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+7
-2
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+3
-0
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+3
-0
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+3
-0
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+2
-0
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+2
-0
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+2
-0
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+6
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+8
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+3
-1
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+2
-0
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+6
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+2
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+2
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+4
-0
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+2
-0
No files found.
vllm/model_executor/models/arctic.py
View file @
f4fc7337
...
@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
...
@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
num_experts
=
config
.
num_local_experts
self
.
num_experts
=
config
.
num_local_experts
self
.
num_experts_per_tok
=
config
.
num_experts_per_tok
self
.
num_experts_per_tok
=
config
.
num_experts_per_tok
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
vllm/model_executor/models/baichuan.py
View file @
f4fc7337
...
@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
...
@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/bart.py
View file @
f4fc7337
...
@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
...
@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
# currently all existing BART models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
config
=
config
self
.
model
=
BartModel
(
config
,
self
.
model
=
BartModel
(
config
,
cache_config
,
cache_config
,
...
...
vllm/model_executor/models/blip2.py
View file @
f4fc7337
...
@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super
().
__init__
()
super
().
__init__
()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
...
...
vllm/model_executor/models/bloom.py
View file @
f4fc7337
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
...
@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
BloomModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
self
.
transformer
.
word_embeddings
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
word_embeddings
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/chatglm.py
View file @
f4fc7337
...
@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
...
@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
8192
)
8192
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
ChatGLMModel
(
config
,
cache_config
,
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
transformer
.
output_layer
.
weight
=
(
self
.
transformer
.
embedding
.
weight
)
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
lm_head
=
self
.
transformer
.
output_layer
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/commandr.py
View file @
f4fc7337
...
@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
...
@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/dbrx.py
View file @
f4fc7337
...
@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
...
@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
if
config
.
tie_word_embeddings
:
raise
ValueError
(
"tie_word_embeddings is not supported for Dbrx models."
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
DbrxModel
(
config
,
cache_config
,
quant_config
)
...
...
vllm/model_executor/models/deepseek.py
View file @
f4fc7337
...
@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gemma.py
View file @
f4fc7337
...
@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/gemma2.py
View file @
f4fc7337
...
@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del
lora_config
# Unused.
del
lora_config
# Unused.
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert
config
.
tie_word_embeddings
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Gemma2Model
(
config
,
cache_config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
logits_processor
=
LogitsProcessor
(
...
...
vllm/model_executor/models/gpt2.py
View file @
f4fc7337
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config
,
cache_config
,
quant_config
,
quant_config
,
prefix
=
"transformer"
)
prefix
=
"transformer"
)
self
.
lm_head
=
self
.
transformer
.
wte
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
f4fc7337
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
...
@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
,
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
,
lora_config
)
lora_config
)
self
.
lm_head
=
self
.
transformer
.
wte
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
transformer
.
vocab_size
,
self
.
transformer
.
embed_dim
,
org_num_embeddings
=
self
.
config
.
vocab_size
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/gpt_neox.py
View file @
f4fc7337
...
@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
...
@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
:
GPTNeoXConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
...
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
...
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
if
self
.
config
.
tie_word_embeddings
:
self
.
embed_out
.
weight
=
self
.
gpt_neox
.
embed_in
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/internlm2.py
View file @
f4fc7337
...
@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/jais.py
View file @
f4fc7337
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
...
@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
JAISModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
self
.
transformer
.
wte
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
self
.
lm_head
=
ParallelLMHead
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
)
if
hasattr
(
config
,
"width_scale"
):
if
hasattr
(
config
,
"width_scale"
):
self
.
output_logits_scale
=
config
.
width_scale
self
.
output_logits_scale
=
config
.
width_scale
else
:
else
:
...
...
vllm/model_executor/models/llava.py
View file @
f4fc7337
...
@@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
...
@@ -331,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -331,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
input_ids: Flattened (concatenated) input_ids corresponding to a
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
batch.
pixel_values: The pixels in each input image.
pixel_values: The pixels in each input image.
See also:
See also:
:class:`LlavaImageInputs`
:class:`LlavaImageInputs`
"""
"""
...
...
vllm/model_executor/models/llava_next.py
View file @
f4fc7337
...
@@ -545,7 +545,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -545,7 +545,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
9047, 13566, 29901]`.
9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
additional image tokens (denoted as `32000`), resulting in:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
...
@@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
batch.
batch.
pixel_values: The pixels in each grid patch for each input image.
pixel_values: The pixels in each grid patch for each input image.
image_sizes: The original `(height, width)` for each input image.
image_sizes: The original `(height, width)` for each input image.
See also:
See also:
:class:`LlavaNextImageInputs`
:class:`LlavaNextImageInputs`
"""
"""
...
...
vllm/model_executor/models/minicpmv.py
View file @
f4fc7337
...
@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
...
@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
# All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
# and config class
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
...
...
vllm/model_executor/models/mixtral.py
View file @
f4fc7337
...
@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment