Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f89d18ff
Unverified
Commit
f89d18ff
authored
Nov 10, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[6/N] pass whole config to inner model (#10205)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
f0f2e563
Changes
69
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
189 additions
and
279 deletions
+189
-279
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+10
-12
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+11
-10
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+10
-10
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+13
-21
vllm/model_executor/models/granitemoe.py
vllm/model_executor/models/granitemoe.py
+12
-21
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+12
-17
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+8
-16
vllm/model_executor/models/internlm2_ve.py
vllm/model_executor/models/internlm2_ve.py
+9
-14
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-3
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+11
-10
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+12
-18
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+12
-32
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+3
-3
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+3
-3
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+3
-3
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+3
-3
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+13
-18
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+17
-21
vllm/model_executor/models/minicpm3.py
vllm/model_executor/models/minicpm3.py
+5
-7
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+19
-37
No files found.
vllm/model_executor/models/gpt_bigcode.py
View file @
f89d18ff
...
...
@@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -189,15 +189,14 @@ class GPTBigCodeBlock(nn.Module):
@
support_torch_compile
class
GPTBigCodeModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPTBigCodeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
assert
not
config
.
add_cross_attention
...
...
@@ -265,7 +264,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
...
...
@@ -273,8 +271,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant
_config
,
lora_config
)
self
.
transformer
=
GPTBigCodeModel
(
vllm_config
=
vllm
_config
,
prefix
=
prefix
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
...
...
vllm/model_executor/models/gpt_j.py
View file @
f89d18ff
...
...
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
GPTJAttention
(
nn
.
Module
):
...
...
@@ -177,14 +178,13 @@ class GPTJBlock(nn.Module):
@
support_torch_compile
class
GPTJModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPTJConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
embed_dim
=
config
.
n_embd
self
.
wte
=
VocabParallelEmbedding
(
...
...
@@ -236,12 +236,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
assert
not
config
.
tie_word_embeddings
self
.
transformer
=
GPTJModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
GPTJModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
n_embd
,
...
...
vllm/model_executor/models/gpt_neox.py
View file @
f89d18ff
...
...
@@ -41,7 +41,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
GPTNeoXAttention
(
nn
.
Module
):
...
...
@@ -189,14 +190,13 @@ class GPTNeoXLayer(nn.Module):
@
support_torch_compile
class
GPTNeoXModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GPTNeoXConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
embed_in
=
VocabParallelEmbedding
(
...
...
@@ -249,11 +249,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
gpt_neox
=
GPTNeoXModel
(
config
,
cache_config
,
quant_config
)
self
.
gpt_neox
=
GPTNeoXModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"gpt_neox"
))
self
.
embed_out
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
vllm/model_executor/models/granite.py
View file @
f89d18ff
...
...
@@ -28,7 +28,7 @@ from transformers import GraniteConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -52,7 +52,8 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
)
class
GraniteMLP
(
nn
.
Module
):
...
...
@@ -257,15 +258,14 @@ class GraniteDecoderLayer(nn.Module):
@
support_torch_compile
class
GraniteModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
...
...
@@ -370,25 +370,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
GraniteModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
self
.
model
=
GraniteModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
vllm/model_executor/models/granitemoe.py
View file @
f89d18ff
...
...
@@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors
from
.
import
mixtral
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
make_layers
from
.utils
import
make_layers
,
maybe_prefix
class
GraniteMoeMoE
(
nn
.
Module
):
...
...
@@ -247,15 +247,14 @@ class GraniteMoeDecoderLayer(nn.Module):
@
support_torch_compile
class
GraniteMoeModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GraniteMoeConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
...
...
@@ -333,25 +332,17 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
GraniteMoeModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
self
.
model
=
GraniteMoeModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/idefics3.py
View file @
f89d18ff
...
...
@@ -22,17 +22,15 @@ import torch.utils.checkpoint
from
PIL
import
Image
from
torch
import
nn
# Temporary solution for transformers below 4.46.0.
from
transformers
import
PretrainedConfig
as
Idefics3Config
from
transformers
import
ProcessorMixin
as
Idefics3ImageProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -48,7 +46,8 @@ from .idefics2_vision_model import (
# yapf: enable
from
.interfaces
import
SupportsMultiModal
from
.llama
import
LlamaModel
from
.utils
import
AutoWeightsLoader
,
flatten_bn
,
merge_multimodal_embeddings
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
...
...
@@ -417,13 +416,13 @@ class Idefics3Connector(nn.Module):
class
Idefics3Model
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Idefics3Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
padding_idx
=
self
.
config
.
text_config
.
pad_token_id
self
.
vocab_size
=
self
.
config
.
text_config
.
vocab_size
...
...
@@ -613,22 +612,18 @@ class Idefics3Model(nn.Module):
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_idefics3
)
class
Idefics3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
model
=
Idefics3Model
(
config
,
cache_config
,
quant_config
)
self
.
model
=
Idefics3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
image_token_id
=
self
.
config
.
image_token_id
self
.
lm_head
=
ParallelLMHead
(
...
...
vllm/model_executor/models/internlm2.py
View file @
f89d18ff
...
...
@@ -250,14 +250,13 @@ class InternLMDecoderLayer(nn.Module):
@
support_torch_compile
class
InternLM2Model
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
...
...
@@ -317,20 +316,13 @@ class InternLM2Model(nn.Module):
class
InternLM2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
InternLM2Model
(
config
,
cache_config
,
quant_config
,
self
.
model
=
InternLM2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
vllm/model_executor/models/internlm2_ve.py
View file @
f89d18ff
...
...
@@ -104,14 +104,13 @@ class InternLM2VEDecoderLayer(nn.Module):
class
InternLM2VEModel
(
InternLM2Model
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
config
,
cache_config
,
quant_config
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
InternLM2VEDecoderLayer
(
...
...
@@ -159,12 +158,8 @@ class InternLM2VEModel(InternLM2Model):
class
InternLM2VEForCausalLM
(
InternLM2ForCausalLM
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
vllm_config
,
prefix
=
prefix
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
...
...
vllm/model_executor/models/internvl.py
View file @
f89d18ff
...
...
@@ -35,7 +35,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
IMG_START
=
'<img>'
IMG_END
=
'</img>'
...
...
@@ -435,13 +435,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
config
,
quant_config
=
quant_config
,
is_mono
=
self
.
is_mono
,
prefix
=
"vision_model"
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
)
,
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
)
...
...
vllm/model_executor/models/jais.py
View file @
f89d18ff
...
...
@@ -44,7 +44,8 @@ from vllm.transformers_utils.configs import JAISConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
SwiGLUActivation
(
nn
.
Module
):
...
...
@@ -215,14 +216,13 @@ class JAISBlock(nn.Module):
@
support_torch_compile
class
JAISModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
JAISConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
assert
not
config
.
add_cross_attention
assert
not
config
.
scale_attn_by_inverse_layer_idx
...
...
@@ -293,11 +293,12 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
transformer
=
JAISModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
JAISModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"transformer"
))
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
transformer
.
wte
else
:
...
...
vllm/model_executor/models/jamba.py
View file @
f89d18ff
...
...
@@ -7,7 +7,7 @@ from transformers import JambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -29,6 +29,7 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size
)
from
.interfaces
import
HasInnerState
,
SupportsLoRA
from
.utils
import
maybe_prefix
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
...
@@ -258,14 +259,14 @@ ALL_DECODER_LAYER_TYPES = {
class
JambaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
JambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
...
...
@@ -348,14 +349,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
scheduler_config
=
vllm_config
.
scheduler_config
assert
not
cache_config
.
enable_prefix_caching
,
\
...
...
@@ -364,10 +360,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
super
().
__init__
()
self
.
config
=
config
self
.
scheduler_config
=
scheduler_config
self
.
model
=
JambaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
lora_config
=
lora_config
)
self
.
model
=
JambaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/llama.py
View file @
f89d18ff
...
...
@@ -28,7 +28,7 @@ from transformers import LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
...
@@ -271,15 +271,14 @@ class LlamaDecoderLayer(nn.Module):
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
...
...
@@ -492,24 +491,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"norm"
:
"model.norm"
}
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
,
lora_config
=
lora_config
,
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
...
...
@@ -652,23 +643,12 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules
=
[]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
,
lora_config
,
self
.
model
=
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
...
...
vllm/model_executor/models/llava.py
View file @
f89d18ff
...
...
@@ -32,7 +32,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
input_processor_for_siglip
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
class
LlavaImagePixelInputs
(
TypedDict
):
...
...
@@ -282,7 +282,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
)
)
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
...
...
@@ -291,7 +291,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/llava_next.py
View file @
f89d18ff
...
...
@@ -31,7 +31,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
init_vllm_registered_model
)
init_vllm_registered_model
,
maybe_prefix
)
class
LlavaNextImagePixelInputs
(
TypedDict
):
...
...
@@ -296,7 +296,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
)
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
self
.
multi_modal_projector
=
LlavaMultiModalProjector
(
...
...
@@ -307,7 +307,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
# The same model class supports both language generation and embedding
# because the architecture name is the same
...
...
vllm/model_executor/models/llava_next_video.py
View file @
f89d18ff
...
...
@@ -29,7 +29,7 @@ from .llava import init_vision_tower_for_llava
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
)
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
# For profile run
_MAX_FRAMES_PER_VIDEO
=
32
...
...
@@ -267,7 +267,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
)
)
self
.
vision_resampler
=
LlavaNextVideoPooler
(
config
)
self
.
multi_modal_projector
=
LlavaNextMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
...
...
@@ -276,7 +276,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
...
...
vllm/model_executor/models/llava_onevision.py
View file @
f89d18ff
...
...
@@ -35,7 +35,7 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
448
...
...
@@ -418,12 +418,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
config
,
quant_config
,
require_post_norm
=
False
,
prefix
=
"vision_tower"
)
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
)
)
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
"language_model"
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
...
...
vllm/model_executor/models/mamba.py
View file @
f89d18ff
...
...
@@ -6,7 +6,7 @@ from torch import nn
from
transformers
import
MambaConfig
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -26,6 +26,8 @@ from vllm.sequence import IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
from
.utils
import
maybe_prefix
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
...
@@ -73,14 +75,14 @@ class MambaDecoderLayer(nn.Module):
class
MambaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
MambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
...
...
@@ -130,14 +132,9 @@ class MambaModel(nn.Module):
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
scheduler_config
=
vllm_config
.
scheduler_config
assert
not
cache_config
.
enable_prefix_caching
,
\
...
...
@@ -146,10 +143,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
super
().
__init__
()
self
.
config
=
config
self
.
scheduler_config
=
scheduler_config
self
.
backbone
=
MambaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
lora_config
=
lora_config
)
self
.
backbone
=
MambaModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"backbone"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/minicpm.py
View file @
f89d18ff
...
...
@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
...
...
@@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
MiniCPMMoE
(
nn
.
Module
):
...
...
@@ -351,15 +352,14 @@ class MiniCPMDecoderLayer(nn.Module):
@
support_torch_compile
class
MiniCPMModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
...
...
@@ -461,24 +461,22 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
prefix
=
prefix
self
.
vllm_config
=
vllm_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
_init_model
()
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
@@ -502,11 +500,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
_init_model
(
self
):
self
.
model
=
MiniCPMModel
(
config
=
self
.
config
,
cache_config
=
self
.
cache_config
,
quant_config
=
self
.
quant_config
,
lora_config
=
self
.
lora_config
)
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
.
model
=
MiniCPMModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
def
forward
(
self
,
...
...
vllm/model_executor/models/minicpm3.py
View file @
f89d18ff
...
...
@@ -28,7 +28,7 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM
,
MiniCPMModel
)
from
.utils
import
make_layers
from
.utils
import
make_layers
,
maybe_prefix
class
MiniCPM3Attention
(
nn
.
Module
):
...
...
@@ -238,8 +238,6 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
# `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM
def
_init_model
(
self
):
self
.
model
=
MiniCPM3Model
(
config
=
self
.
config
,
cache_config
=
self
.
cache_config
,
quant_config
=
self
.
quant_config
,
lora_config
=
self
.
lora_config
)
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
self
.
model
=
MiniCPM3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
vllm/model_executor/models/minicpmv.py
View file @
f89d18ff
...
...
@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
is_pp_missing_parameter
from
.utils
import
is_pp_missing_parameter
,
maybe_prefix
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
...
...
@@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
):
config
=
vllm_config
.
model_config
.
hf_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
super
().
__init__
()
# All MiniCPM-V models disable `tie_word_embeddings` but
...
...
@@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
multimodal_config
=
multimodal_config
self
.
version
=
get_version_by_config
(
self
.
config
)
self
.
llm
=
self
.
init_llm
(
config
,
cache_config
,
quant_
config
,
prefix
=
"llm"
)
self
.
vpm
=
self
.
init_vision_module
(
config
,
quant_config
,
prefix
=
"vpm"
)
self
.
llm
=
self
.
init_llm
(
vllm_config
=
vllm_
config
,
prefix
=
maybe_prefix
(
prefix
,
"llm"
))
self
.
vpm
=
self
.
init_vision_module
(
config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vpm"
)
)
param_dtype
=
torch
.
get_default_dtype
()
self
.
vpm
.
to
(
dtype
=
param_dtype
)
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
...
...
@@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
,
quant_config
=
quant_config
,
prefix
=
"resampler"
)
prefix
=
maybe_prefix
(
prefix
,
"resampler"
))
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
"llm.lm_head"
)
prefix
=
maybe_prefix
(
prefix
,
"llm.lm_head"
))
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
...
...
@@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
raise
NotImplementedError
...
...
@@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
return
LLMWrapper
(
MiniCPMModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
return
LLMWrapper
(
MiniCPMModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
name
=
"model"
)
def
init_vision_module
(
...
...
@@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
return
LLMWrapper
(
LlamaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
return
LLMWrapper
(
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
name
=
"model"
)
def
init_vision_module
(
...
...
@@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def
init_llm
(
self
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
nn
.
Module
:
return
LLMWrapper
(
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
return
LLMWrapper
(
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
name
=
"model"
)
def
init_vision_module
(
...
...
@@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
if
instance_class
is
None
:
raise
ValueError
(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6"
)
return
instance_class
(
vllm_config
,
prefix
=
prefix
)
return
instance_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment