Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13370712
Unverified
Commit
13370712
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 01, 2024
Browse files
[Model] Replace embedding models with pooling adapter (#10769)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7e4bbda5
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
125 additions
and
96 deletions
+125
-96
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+3
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+3
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+14
-25
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+3
-2
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+12
-16
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-16
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+41
-18
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+3
-2
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+19
-5
vllm/multimodal/base.py
vllm/multimodal/base.py
+3
-3
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+3
-2
vllm/utils.py
vllm/utils.py
+19
-3
No files found.
vllm/model_executor/models/llava_onevision.py
View file @
13370712
...
...
@@ -422,9 +422,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
))
self
.
multi_modal_projector
=
LlavaOnevisionMultiModalProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
image_newline
=
nn
.
Parameter
(
torch
.
empty
(
config
.
text_config
.
hidden_size
))
...
...
vllm/model_executor/models/paligemma.py
View file @
13370712
...
...
@@ -151,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
quant_config
=
quant_config
config
.
text_config
.
architectures
=
[
"GemmaForCausalLM"
]
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
...
...
vllm/model_executor/models/phi3v.py
View file @
13370712
...
...
@@ -29,24 +29,22 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.utils
import
cached_get_tokenizer
,
repeat_and_pad_token
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
...
...
@@ -536,7 +534,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
@@ -556,18 +553,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"model.vision_embed_tokens"
))
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
self
.
language_model
=
LlamaForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
""
)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
prefix
=
""
,
# We don't directly initialize vLLM's LlamaForCausalLM so we
# can automatically apply embedding wrapper if this model is
# initialized as an embedding model
architectures
=
[
"LlamaForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -739,13 +735,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
hf_to_vllm_mapper
=
WeightsMapper
(
...
...
vllm/model_executor/models/pixtral.py
View file @
13370712
...
...
@@ -172,9 +172,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
vision_encoder
=
VisionTransformer
(
self
.
vision_args
)
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
...
...
vllm/model_executor/models/qwen2.py
View file @
13370712
...
...
@@ -31,6 +31,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -55,6 +56,8 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
class
Qwen2MLP
(
nn
.
Module
):
...
...
@@ -433,7 +436,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
config
=
config
self
.
lora_config
=
lora_config
...
...
@@ -454,14 +456,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
# The same model class supports both language generation and embedding
# because the architecture name is the same
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -499,13 +493,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
...
...
@@ -553,6 +540,15 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
# TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
# after changing the default pooling method
if
pooler_config
.
pooling_type
is
None
:
logger
.
warning
(
"This embedding model will default to last-token pooling in "
"an upcoming version. To avoid breaking changes, you should "
"pass `--override-pooler-config '{
\"
pooling_type
\"
:
\"
MEAN
\"
}'`"
" explicitly."
)
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
MEAN
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
13370712
...
...
@@ -50,7 +50,6 @@ from vllm.model_executor.layers.activation import QuickGELU
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
...
...
@@ -59,14 +58,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.inputs
import
(
MultiModalData
,
MultiModalDataDict
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
cached_get_processor
...
...
@@ -1070,7 +1068,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
assert
not
cache_config
.
enable_prefix_caching
,
\
"Qwen2-VL currently does not support prefix caching"
...
...
@@ -1102,11 +1099,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
softmax
=
False
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
...
...
@@ -1361,13 +1354,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
...
...
vllm/model_executor/models/registry.py
View file @
13370712
...
...
@@ -20,6 +20,7 @@ import torch.nn as nn
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.adapters
import
as_embedding_model
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
)
...
...
@@ -107,15 +108,15 @@ _EMBEDDING_MODELS = {
"RobertaForMaskedLM"
:
(
"roberta"
,
"RobertaEmbeddingModel"
),
"XLMRobertaModel"
:
(
"roberta"
,
"RobertaEmbeddingModel"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2
EmbeddingModel
"
),
"Gemma2Model"
:
(
"gemma2"
,
"Gemma2
ForCausalLM
"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"LlamaModel"
:
(
"llama"
,
"Llama
EmbeddingModel
"
),
"LlamaModel"
:
(
"llama"
,
"Llama
ForCausalLM
"
),
**
{
# Multiple models share the same architecture, so we include them all
k
:
(
mod
,
arch
)
for
k
,
(
mod
,
arch
)
in
_TEXT_GENERATION_MODELS
.
items
()
if
arch
==
"LlamaForCausalLM"
},
"MistralModel"
:
(
"llama"
,
"Llama
EmbeddingModel
"
),
"MistralModel"
:
(
"llama"
,
"Llama
ForCausalLM
"
),
"Phi3ForCausalLM"
:
(
"phi3"
,
"Phi3ForCausalLM"
),
"Qwen2Model"
:
(
"qwen2"
,
"Qwen2EmbeddingModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
...
...
@@ -125,7 +126,7 @@ _EMBEDDING_MODELS = {
# [Multimodal]
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
)
# noqa: E501
,
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
)
,
# noqa: E501
}
_CROSS_ENCODER_MODELS
=
{
...
...
@@ -208,6 +209,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
@
dataclass
(
frozen
=
True
)
class
_ModelInfo
:
architecture
:
str
is_text_generation_model
:
bool
is_embedding_model
:
bool
supports_cross_encoding
:
bool
...
...
@@ -218,9 +220,19 @@ class _ModelInfo:
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
is_embedding_model_
=
is_embedding_model
(
model
)
if
not
is_embedding_model_
:
try
:
as_embedding_model
(
model
)
except
Exception
:
pass
else
:
is_embedding_model_
=
True
return
_ModelInfo
(
architecture
=
model
.
__name__
,
is_text_generation_model
=
is_text_generation_model
(
model
),
is_embedding_model
=
is_embedding_model
(
model
)
,
is_embedding_model
=
is_embedding_model
_
,
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_pp
=
supports_pp
(
model
),
...
...
@@ -399,13 +411,13 @@ class _ModelRegistry:
def inspect_model_cls(
    self,
    architectures: Union[str, List[str]],
) -> Tuple["_ModelInfo", str]:
    """Resolve the first inspectable architecture and return its info.

    Args:
        architectures: A single architecture name or a list of candidate
            names; it is normalized through ``self._normalize_archs``.

    Returns:
        A ``(model_info, arch)`` pair, where ``arch`` is the candidate name
        for which ``self._try_inspect_model_cls`` returned a non-``None``
        result. Candidates are tried in order and the first hit wins.

    If no candidate can be inspected, the result of
    ``self._raise_for_unsupported(architectures)`` is returned
    (presumably that helper raises; confirm against its definition).
    """
    architectures = self._normalize_archs(architectures)

    for arch in architectures:
        model_info = self._try_inspect_model_cls(arch)
        if model_info is not None:
            # Return the matched name too, so callers know which of the
            # candidate architectures was actually resolved.
            return (model_info, arch)

    return self._raise_for_unsupported(architectures)
...
...
@@ -426,39 +438,50 @@ class _ModelRegistry:
self
,
architectures
:
Union
[
str
,
List
[
str
]],
)
->
bool
:
return
self
.
inspect_model_cls
(
architectures
).
is_text_generation_model
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
is_text_generation_model
def is_embedding_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``is_embedding_model`` flag for the given architectures.

    ``inspect_model_cls`` is annotated as returning
    ``Tuple[_ModelInfo, str]``; only the info object is needed here, so the
    resolved architecture name is discarded.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.is_embedding_model
def is_cross_encoder_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_cross_encoding`` flag for the architectures.

    ``inspect_model_cls`` is annotated as returning
    ``Tuple[_ModelInfo, str]``; only the info object is needed here, so the
    resolved architecture name is discarded.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_cross_encoding
def is_multimodal_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_multimodal`` flag for the given architectures.

    ``inspect_model_cls`` is annotated as returning
    ``Tuple[_ModelInfo, str]``; only the info object is needed here, so the
    resolved architecture name is discarded.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_multimodal
def is_pp_supported_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_pp`` flag for the given architectures.

    ``inspect_model_cls`` is annotated as returning
    ``Tuple[_ModelInfo, str]``; only the info object is needed here, so the
    resolved architecture name is discarded.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.supports_pp
def model_has_inner_state(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``has_inner_state`` flag for the given architectures.

    ``inspect_model_cls`` is annotated as returning
    ``Tuple[_ModelInfo, str]``; only the info object is needed here, so the
    resolved architecture name is discarded.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.has_inner_state
def is_attention_free_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``is_attention_free`` flag for the given architectures.

    ``inspect_model_cls`` is annotated as returning
    ``Tuple[_ModelInfo, str]``; only the info object is needed here, so the
    resolved architecture name is discarded.
    """
    model_info, _ = self.inspect_model_cls(architectures)
    return model_info.is_attention_free
ModelRegistry
=
_ModelRegistry
({
...
...
vllm/model_executor/models/ultravox.py
View file @
13370712
...
...
@@ -360,9 +360,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
if
config
.
text_model_id
is
not
None
:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
...
...
vllm/model_executor/models/utils.py
View file @
13370712
...
...
@@ -173,8 +173,15 @@ class AutoWeightsLoader:
module_load_weights
=
getattr
(
module
,
"load_weights"
,
None
)
if
callable
(
module_load_weights
):
loaded_params
=
module_load_weights
(
weights
)
yield
from
map
(
lambda
x
:
self
.
_get_qualname
(
base_prefix
,
x
),
loaded_params
)
if
loaded_params
is
None
:
logger
.
warning
(
"Unable to collect loaded parameters "
"for module %s"
,
module
)
else
:
yield
from
map
(
lambda
x
:
self
.
_get_qualname
(
base_prefix
,
x
),
loaded_params
,
)
child_modules
=
dict
(
module
.
named_children
())
child_params
=
dict
(
module
.
named_parameters
(
recurse
=
False
))
...
...
@@ -232,17 +239,24 @@ class AutoWeightsLoader:
def init_vllm_registered_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    hf_config: Optional[PretrainedConfig] = None,
    architectures: Optional[list[str]] = None,
) -> nn.Module:
    """
    Helper function to initialize an inner model registered to vLLM,
    based on the arguments passed to the outer vLLM model.

    Args:
        vllm_config: Configuration of the outer model.
        prefix: Forwarded unchanged to ``_initialize_model``.
        hf_config: When given, ``vllm_config`` is replaced by
            ``vllm_config.with_hf_config(hf_config)`` before initialization;
            when ``None``, ``vllm_config`` is used as-is.
        architectures: Forwarded unchanged to ``_initialize_model``.

    Returns:
        Whatever ``_initialize_model`` returns for the resulting config.
    """
    # Imported lazily inside the function, as in the original code
    # (presumably to avoid an import cycle with the model loader).
    from vllm.model_executor.model_loader.loader import _initialize_model

    if hf_config is not None:
        vllm_config = vllm_config.with_hf_config(hf_config)

    return _initialize_model(vllm_config=vllm_config,
                             prefix=prefix,
                             architectures=architectures)
@
overload
...
...
vllm/multimodal/base.py
View file @
13370712
...
...
@@ -7,7 +7,7 @@ from torch import nn
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_allowed_kwarg_only_overrides
,
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
resolve_mm_processor_kwargs
)
if
TYPE_CHECKING
:
...
...
@@ -54,8 +54,8 @@ class MultiModalPlugin(ABC):
"""
def
__init__
(
self
)
->
None
:
self
.
_input_mappers
:
Dict
[
Type
[
nn
.
Module
]
,
MultiModalInputMapper
]
=
{}
self
.
_max_mm_tokens
:
Dict
[
Type
[
nn
.
Module
]
,
MultiModalTokensCalc
]
=
{}
self
.
_input_mappers
=
ClassRegistry
[
nn
.
Module
,
MultiModalInputMapper
]
()
self
.
_max_mm_tokens
=
ClassRegistry
[
nn
.
Module
,
MultiModalTokensCalc
]
()
@
abstractmethod
def
get_data_key
(
self
)
->
str
:
...
...
vllm/multimodal/registry.py
View file @
13370712
...
...
@@ -9,6 +9,7 @@ from typing_extensions import TypeAlias
from
vllm.inputs
import
InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
ClassRegistry
from
.audio
import
AudioPlugin
from
.base
import
MultiModalInputMapper
,
MultiModalPlugin
,
MultiModalTokensCalc
...
...
@@ -62,8 +63,8 @@ class MultiModalRegistry:
plugins
:
Sequence
[
MultiModalPlugin
]
=
DEFAULT_PLUGINS
)
->
None
:
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_processor_factories
:
Dict
[
Type
[
nn
.
Module
]
,
MultiModalProcessorFactory
]
=
{}
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
MultiModalProcessorFactory
]
()
# This is used for non-multimodal models
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
...
...
vllm/utils.py
View file @
13370712
...
...
@@ -20,7 +20,7 @@ import uuid
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
collections
import
defaultdict
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
...
...
@@ -1517,13 +1517,13 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class
LazyDict
(
Mapping
,
Generic
[
T
]):
class
LazyDict
(
Mapping
[
str
,
T
]
,
Generic
[
T
]):
def
__init__
(
self
,
factory
:
Dict
[
str
,
Callable
[[],
T
]]):
self
.
_factory
=
factory
self
.
_dict
:
Dict
[
str
,
T
]
=
{}
def
__getitem__
(
self
,
key
)
->
T
:
def
__getitem__
(
self
,
key
:
str
)
->
T
:
if
key
not
in
self
.
_dict
:
if
key
not
in
self
.
_factory
:
raise
KeyError
(
key
)
...
...
@@ -1540,6 +1540,22 @@ class LazyDict(Mapping, Generic[T]):
return
len
(
self
.
_factory
)
class ClassRegistry(UserDict[type[T], _V]):
    """Dictionary keyed by classes whose lookups fall back along the MRO.

    A value stored for a class is also found when looking up any of its
    subclasses: the key's method resolution order is walked from most to
    least specific and the first registered class wins. Storage, iteration
    and length are inherited from ``UserDict`` and operate only on the
    classes that were registered explicitly.
    """

    def __getitem__(self, key: type[T]) -> _V:
        """Return the value of the nearest registered class in ``key``'s MRO.

        Raises:
            KeyError: If neither ``key`` nor any of its bases is registered.
        """
        # ``__mro__`` is read instead of calling ``key.mro()``: the call form
        # fails with TypeError when ``key`` is ``type`` itself or any other
        # metaclass, because it resolves to the unbound ``type.mro``.
        for cls in key.__mro__:
            if cls in self.data:
                return self.data[cls]

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` or one of its bases is registered.

        Objects that are not classes are never contained.
        """
        if not isinstance(key, type):
            return False

        return any(cls in self.data for cls in key.__mro__)
def
weak_ref_tensor
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Create a weak reference to a tensor.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment