Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13370712
Unverified
Commit
13370712
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 01, 2024
Browse files
[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7e4bbda5
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
125 additions
and
96 deletions
+125
-96
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+3
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+3
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+14
-25
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+3
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+12
-16
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-16
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+41
-18
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+3
-2
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+19
-5
vllm/multimodal/base.py
vllm/multimodal/base.py
+3
-3
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+3
-2
vllm/utils.py
vllm/utils.py
+19
-3
No files found.
vllm/model_executor/models/llava_onevision.py
View file @
13370712
...
@@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
))
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
))
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
image_newline
=
nn
.
Parameter
(
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
torch
.
empty
(
config
.
text_config
.
hidden_size
))
...
...
vllm/model_executor/models/paligemma.py
View file @
13370712
...
@@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
config
.
text_config
.
architectures
=
[
"GemmaForCausalLM"
]
config
.
text_config
.
architectures
=
[
"GemmaForCausalLM"
]
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
...
...
vllm/model_executor/models/phi3v.py
View file @
13370712
...
@@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
...
@@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.utils
import
cached_get_tokenizer
,
repeat_and_pad_token
from
vllm.multimodal.utils
import
cached_get_tokenizer
,
repeat_and_pad_token
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
...
@@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model.vision_embed_tokens"
))
prefix
=
maybe_prefix
(
prefix
,
"model.vision_embed_tokens"
))
# The prefix is empty intentionally because default prefix of
self
.
language_model
=
init_vllm_registered_model
(
# LlamaForCausalLM is "model"
vllm_config
=
vllm_config
,
self
.
language_model
=
LlamaForCausalLM
(
vllm_config
=
vllm_config
,
# The prefix is empty intentionally because default prefix of
prefix
=
""
)
# LlamaForCausalLM is "model"
prefix
=
""
,
# The same model class supports both language generation and embedding
# We don't directly initialize vLLM's LlamaForCausalLM so we
# because the architecture name is the same
# can automatically apply embedding wrapper if this model is
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
# initialized as an embedding model
pooler_config
,
architectures
=
[
"LlamaForCausalLM"
],
pooling_type
=
PoolingType
.
LAST
,
)
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
...
@@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
hf_to_vllm_mapper
=
WeightsMapper
(
hf_to_vllm_mapper
=
WeightsMapper
(
...
...
vllm/model_executor/models/pixtral.py
View file @
13370712
...
@@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM
# init MistralForCausalLM
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
vision_encoder
=
VisionTransformer
(
self
.
vision_args
)
self
.
vision_encoder
=
VisionTransformer
(
self
.
vision_args
)
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
...
...
vllm/model_executor/models/qwen2.py
View file @
13370712
...
@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
...
@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
...
@@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
logger
=
init_logger
(
__name__
)
class
Qwen2MLP
(
nn
.
Module
):
class
Qwen2MLP
(
nn
.
Module
):
...
@@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
...
@@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
# The same model class supports both language generation and embedding
# because the architecture name is the same
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
@@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
loader
=
AutoWeightsLoader
(
...
@@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
if
pooler_config
.
pooling_type
is
None
:
logger
.
warning
(
"This embedding model will default to last-token pooling in "
"an upcoming version. To avoid breaking changes, you should "
"pass `--override-pooler-config '{
\"
pooling_type
\"
:
\"
MEAN
\"
}'`"
" explicitly."
)
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
MEAN
,
pooling_type
=
PoolingType
.
MEAN
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
13370712
...
@@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
...
@@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
...
@@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...
@@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.inputs
import
(
MultiModalData
,
MultiModalDataDict
,
from
vllm.multimodal.inputs
import
(
MultiModalData
,
MultiModalDataDict
,
MultiModalKwargs
,
NestedTensors
)
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.processor
import
cached_get_processor
...
@@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
assert
not
cache_config
.
enable_prefix_caching
,
\
assert
not
cache_config
.
enable_prefix_caching
,
\
"Qwen2-VL currently does not support prefix caching"
"Qwen2-VL currently does not support prefix caching"
...
@@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
...
@@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
...
vllm/model_executor/models/registry.py
View file @
13370712
...
@@ -20,6 +20,7 @@ import torch.nn as nn
...
@@ -20,6 +20,7 @@ import torch.nn as nn
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.adapters
import
as_embedding_model
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
supports_cross_encoding
,
supports_multimodal
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
)
supports_pp
)
...
@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
...
@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
"RobertaForMaskedLM"
:
(
"roberta"
,
"RobertaEmbeddingModel"
),
"RobertaForMaskedLM"
:
(
"roberta"
,
"RobertaEmbeddingModel"
),
"XLMRobertaModel"
:
(
"roberta"
,
"RobertaEmbeddingModel"
),
"XLMRobertaModel"
:
(
"roberta"
,
"RobertaEmbeddingModel"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2
EmbeddingModel
"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2
ForCausalLM
"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"LlamaModel"
:
(
"llama"
,
"Llama
EmbeddingModel
"
),
"LlamaModel"
:
(
"llama"
,
"Llama
ForCausalLM
"
),
**
{
**
{
# Multiple models share the same architecture, so we include them all
# Multiple models share the same architecture, so we include them all
k
:
(
mod
,
arch
)
for
k
,
(
mod
,
arch
)
in
_TEXT_GENERATION_MODELS
.
items
()
k
:
(
mod
,
arch
)
for
k
,
(
mod
,
arch
)
in
_TEXT_GENERATION_MODELS
.
items
()
if
arch
==
"LlamaForCausalLM"
if
arch
==
"LlamaForCausalLM"
},
},
"MistralModel"
:
(
"llama"
,
"Llama
EmbeddingModel
"
),
"MistralModel"
:
(
"llama"
,
"Llama
ForCausalLM
"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"Qwen2Model"
:
(
"qwen2"
,
"Qwen2EmbeddingModel"
),
"Qwen2Model"
:
(
"qwen2"
,
"Qwen2EmbeddingModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
...
@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
...
@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
# [Multimodal]
# [Multimodal]
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
)
# noqa: E501
,
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
)
,
# noqa: E501
}
}
_CROSS_ENCODER_MODELS
=
{
_CROSS_ENCODER_MODELS
=
{
...
@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
...
@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
_ModelInfo
:
class
_ModelInfo
:
architecture
:
str
is_text_generation_model
:
bool
is_text_generation_model
:
bool
is_embedding_model
:
bool
is_embedding_model
:
bool
supports_cross_encoding
:
bool
supports_cross_encoding
:
bool
...
@@ -218,9 +220,19 @@ class _ModelInfo:
...
@@ -218,9 +220,19 @@ class _ModelInfo:
@
staticmethod
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
is_embedding_model_
=
is_embedding_model
(
model
)
if
not
is_embedding_model_
:
try
:
as_embedding_model
(
model
)
except
Exception
:
pass
else
:
is_embedding_model_
=
True
return
_ModelInfo
(
return
_ModelInfo
(
architecture
=
model
.
__name__
,
is_text_generation_model
=
is_text_generation_model
(
model
),
is_text_generation_model
=
is_text_generation_model
(
model
),
is_embedding_model
=
is_embedding_model
(
model
)
,
is_embedding_model
=
is_embedding_model
_
,
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_pp
=
supports_pp
(
model
),
supports_pp
=
supports_pp
(
model
),
...
@@ -399,13 +411,13 @@ class _ModelRegistry:
...
@@ -399,13 +411,13 @@ class _ModelRegistry:
def
inspect_model_cls
(
def
inspect_model_cls
(
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
_ModelInfo
:
)
->
Tuple
[
_ModelInfo
,
str
]
:
architectures
=
self
.
_normalize_archs
(
architectures
)
architectures
=
self
.
_normalize_archs
(
architectures
)
for
arch
in
architectures
:
for
arch
in
architectures
:
model_info
=
self
.
_try_inspect_model_cls
(
arch
)
model_info
=
self
.
_try_inspect_model_cls
(
arch
)
if
model_info
is
not
None
:
if
model_info
is
not
None
:
return
model_info
return
(
model_info
,
arch
)
return
self
.
_raise_for_unsupported
(
architectures
)
return
self
.
_raise_for_unsupported
(
architectures
)
...
@@ -426,39 +438,50 @@ class _ModelRegistry:
...
@@ -426,39 +438,50 @@ class _ModelRegistry:
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
)
->
bool
:
return
self
.
inspect_model_cls
(
architectures
).
is_text_generation_model
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_text_generation_model
def
is_embedding_model
(
def
is_embedding_model
(
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
)
->
bool
:
return
self
.
inspect_model_cls
(
architectures
).
is_embedding_model
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_embedding_model
def
is_cross_encoder_model
(
def
is_cross_encoder_model
(
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
)
->
bool
:
return
self
.
inspect_model_cls
(
architectures
).
supports_cross_encoding
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_cross_encoding
def
is_multimodal_model
(
def
is_multimodal_model
(
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
)
->
bool
:
return
self
.
inspect_model_cls
(
architectures
).
supports_multimodal
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_multimodal
def
is_pp_supported_model
(
def
is_pp_supported_model
(
self
,
self
,
architectures
:
Union
[
str
,
List
[
str
]],
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
)
->
bool
:
return
self
.
inspect_model_cls
(
architectures
).
supports_pp
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_pp
def
model_has_inner_state
(
self
,
architectures
:
Union
[
str
,
def
model_has_inner_state
(
List
[
str
]])
->
bool
:
self
,
return
self
.
inspect_model_cls
(
architectures
).
has_inner_state
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
has_inner_state
def
is_attention_free_model
(
self
,
architectures
:
Union
[
str
,
def
is_attention_free_model
(
List
[
str
]])
->
bool
:
self
,
return
self
.
inspect_model_cls
(
architectures
).
is_attention_free
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_attention_free
ModelRegistry
=
_ModelRegistry
({
ModelRegistry
=
_ModelRegistry
({
...
...
vllm/model_executor/models/ultravox.py
View file @
13370712
...
@@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
))
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
if
config
.
text_model_id
is
not
None
:
if
config
.
text_model_id
is
not
None
:
# this prefix is not for initialization, but for loading weights
# this prefix is not for initialization, but for loading weights
# note the trailing dot
# note the trailing dot
...
...
vllm/model_executor/models/utils.py
View file @
13370712
...
@@ -173,8 +173,15 @@ class AutoWeightsLoader:
...
@@ -173,8 +173,15 @@ class AutoWeightsLoader:
module_load_weights
=
getattr
(
module
,
"load_weights"
,
None
)
module_load_weights
=
getattr
(
module
,
"load_weights"
,
None
)
if
callable
(
module_load_weights
):
if
callable
(
module_load_weights
):
loaded_params
=
module_load_weights
(
weights
)
loaded_params
=
module_load_weights
(
weights
)
yield
from
map
(
lambda
x
:
self
.
_get_qualname
(
base_prefix
,
x
),
if
loaded_params
is
None
:
loaded_params
)
logger
.
warning
(
"Unable to collect loaded parameters "
"for module %s"
,
module
)
else
:
yield
from
map
(
lambda
x
:
self
.
_get_qualname
(
base_prefix
,
x
),
loaded_params
,
)
child_modules
=
dict
(
module
.
named_children
())
child_modules
=
dict
(
module
.
named_children
())
child_params
=
dict
(
module
.
named_parameters
(
recurse
=
False
))
child_params
=
dict
(
module
.
named_parameters
(
recurse
=
False
))
...
@@ -232,17 +239,24 @@ class AutoWeightsLoader:
...
@@ -232,17 +239,24 @@ class AutoWeightsLoader:
def
init_vllm_registered_model
(
def
init_vllm_registered_model
(
hf_config
:
PretrainedConfig
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
hf_config
:
Optional
[
PretrainedConfig
]
=
None
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
"""
"""
Helper function to initialize an inner model registered to vLLM,
Helper function to initialize an inner model registered to vLLM,
based on the arguments passed to the outer vLLM model.
based on the arguments passed to the outer vLLM model.
"""
"""
from
vllm.model_executor.model_loader.loader
import
_initialize_model
from
vllm.model_executor.model_loader.loader
import
_initialize_model
vllm_config
=
vllm_config
.
with_hf_config
(
hf_config
)
return
_initialize_model
(
vllm_config
,
prefix
)
if
hf_config
is
not
None
:
vllm_config
=
vllm_config
.
with_hf_config
(
hf_config
)
return
_initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
architectures
=
architectures
)
@
overload
@
overload
...
...
vllm/multimodal/base.py
View file @
13370712
...
@@ -7,7 +7,7 @@ from torch import nn
...
@@ -7,7 +7,7 @@ from torch import nn
from
vllm.inputs
import
InputContext
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_allowed_kwarg_only_overrides
,
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
resolve_mm_processor_kwargs
)
resolve_mm_processor_kwargs
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
...
@@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
"""
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
_input_mappers
:
Dict
[
Type
[
nn
.
Module
]
,
MultiModalInputMapper
]
=
{}
self
.
_input_mappers
=
ClassRegistry
[
nn
.
Module
,
MultiModalInputMapper
]
()
self
.
_max_mm_tokens
:
Dict
[
Type
[
nn
.
Module
]
,
MultiModalTokensCalc
]
=
{}
self
.
_max_mm_tokens
=
ClassRegistry
[
nn
.
Module
,
MultiModalTokensCalc
]
()
@
abstractmethod
@
abstractmethod
def
get_data_key
(
self
)
->
str
:
def
get_data_key
(
self
)
->
str
:
...
...
vllm/multimodal/registry.py
View file @
13370712
...
@@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
...
@@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
ClassRegistry
from
.audio
import
AudioPlugin
from
.audio
import
AudioPlugin
from
.base
import
MultiModalInputMapper
,
MultiModalPlugin
,
MultiModalTokensCalc
from
.base
import
MultiModalInputMapper
,
MultiModalPlugin
,
MultiModalTokensCalc
...
@@ -62,8 +63,8 @@ class MultiModalRegistry:
...
@@ -62,8 +63,8 @@ class MultiModalRegistry:
plugins
:
Sequence
[
MultiModalPlugin
]
=
DEFAULT_PLUGINS
)
->
None
:
plugins
:
Sequence
[
MultiModalPlugin
]
=
DEFAULT_PLUGINS
)
->
None
:
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_processor_factories
:
Dict
[
Type
[
nn
.
Module
]
,
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
MultiModalProcessorFactory
]
=
{}
MultiModalProcessorFactory
]
()
# This is used for non-multimodal models
# This is used for non-multimodal models
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
...
...
vllm/utils.py
View file @
13370712
...
@@ -20,7 +20,7 @@ import uuid
...
@@ -20,7 +20,7 @@ import uuid
import
warnings
import
warnings
import
weakref
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
collections
import
defaultdict
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
...
@@ -1517,13 +1517,13 @@ class AtomicCounter:
...
@@ -1517,13 +1517,13 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class
LazyDict
(
Mapping
,
Generic
[
T
]):
class
LazyDict
(
Mapping
[
str
,
T
]
,
Generic
[
T
]):
def
__init__
(
self
,
factory
:
Dict
[
str
,
Callable
[[],
T
]]):
def
__init__
(
self
,
factory
:
Dict
[
str
,
Callable
[[],
T
]]):
self
.
_factory
=
factory
self
.
_factory
=
factory
self
.
_dict
:
Dict
[
str
,
T
]
=
{}
self
.
_dict
:
Dict
[
str
,
T
]
=
{}
def
__getitem__
(
self
,
key
)
->
T
:
def
__getitem__
(
self
,
key
:
str
)
->
T
:
if
key
not
in
self
.
_dict
:
if
key
not
in
self
.
_dict
:
if
key
not
in
self
.
_factory
:
if
key
not
in
self
.
_factory
:
raise
KeyError
(
key
)
raise
KeyError
(
key
)
...
@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
...
@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
return
len
(
self
.
_factory
)
return
len
(
self
.
_factory
)
class
ClassRegistry
(
UserDict
[
type
[
T
],
_V
]):
def
__getitem__
(
self
,
key
:
type
[
T
])
->
_V
:
for
cls
in
key
.
mro
():
if
cls
in
self
.
data
:
return
self
.
data
[
cls
]
raise
KeyError
(
key
)
def
__contains__
(
self
,
key
:
object
)
->
bool
:
if
not
isinstance
(
key
,
type
):
return
False
return
any
(
cls
in
self
.
data
for
cls
in
key
.
mro
())
def
weak_ref_tensor
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
weak_ref_tensor
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
Create a weak reference to a tensor.
Create a weak reference to a tensor.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment